Repository: baidu/unit-dmkit
Branch: master
Commit: 6b837ee07504
Files: 97
Total size: 849.3 KB
Directory structure:
gitextract_m7sg05uu/
├── .gitignore
├── CMakeLists.txt
├── Dockerfile
├── LICENSE
├── NOTICE
├── README.md
├── conf/
│ ├── app/
│ │ ├── bot_tokens.json
│ │ ├── demo/
│ │ │ ├── book_hotel.json
│ │ │ ├── book_hotel.xml
│ │ │ ├── cellular_data.json
│ │ │ ├── cellular_data.xml
│ │ │ ├── quota_adjust.json
│ │ │ └── quota_adjust.xml
│ │ ├── products.json
│ │ └── remote_services.json
│ └── gflags.conf
├── deps.sh
├── docs/
│ ├── demo_book_hotel_pattern.txt
│ ├── demo_cellular_data_pattern.txt
│ ├── demo_quota_adjust_pattern.txt
│ ├── demo_skills.md
│ ├── faq.md
│ ├── tutorial.md
│ └── visual_tool.md
├── language_compiler/
│ ├── compiler_xml.py
│ ├── run.py
│ └── settings.cfg
├── proto/
│ └── http.proto
├── src/
│ ├── app_container.cpp
│ ├── app_container.h
│ ├── app_log.h
│ ├── application_base.h
│ ├── brpc.h
│ ├── butil.h
│ ├── dialog_manager.cpp
│ ├── dialog_manager.h
│ ├── file_watcher.cpp
│ ├── file_watcher.h
│ ├── policy.cpp
│ ├── policy.h
│ ├── policy_manager.cpp
│ ├── policy_manager.h
│ ├── qu_result.cpp
│ ├── qu_result.h
│ ├── rapidjson.h
│ ├── remote_service_manager.cpp
│ ├── remote_service_manager.h
│ ├── request_context.cpp
│ ├── request_context.h
│ ├── server.cpp
│ ├── thirdparty/
│ │ └── rapidjson/
│ │ ├── allocators.h
│ │ ├── document.h
│ │ ├── encodedstream.h
│ │ ├── encodings.h
│ │ ├── error/
│ │ │ ├── en.h
│ │ │ └── error.h
│ │ ├── filereadstream.h
│ │ ├── filewritestream.h
│ │ ├── fwd.h
│ │ ├── internal/
│ │ │ ├── biginteger.h
│ │ │ ├── diyfp.h
│ │ │ ├── dtoa.h
│ │ │ ├── ieee754.h
│ │ │ ├── itoa.h
│ │ │ ├── meta.h
│ │ │ ├── pow10.h
│ │ │ ├── regex.h
│ │ │ ├── stack.h
│ │ │ ├── strfunc.h
│ │ │ ├── strtod.h
│ │ │ └── swap.h
│ │ ├── istreamwrapper.h
│ │ ├── memorybuffer.h
│ │ ├── memorystream.h
│ │ ├── msinttypes/
│ │ │ ├── inttypes.h
│ │ │ └── stdint.h
│ │ ├── ostreamwrapper.h
│ │ ├── pointer.h
│ │ ├── prettywriter.h
│ │ ├── rapidjson.h
│ │ ├── reader.h
│ │ ├── schema.h
│ │ ├── stream.h
│ │ ├── stringbuffer.h
│ │ └── writer.h
│ ├── thread_data_base.h
│ ├── token_manager.cpp
│ ├── token_manager.h
│ ├── user_function/
│ │ ├── demo.cpp
│ │ ├── demo.h
│ │ ├── shared.cpp
│ │ └── shared.h
│ ├── user_function_manager.cpp
│ ├── user_function_manager.h
│ └── utils.h
└── tools/
├── bot_emulator.py
└── mock_api_server.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
/output
/Makefile
/brpc
/_build
/build
/.vscode
*.pyc
.DS_Store
================================================
FILE: CMakeLists.txt
================================================
cmake_minimum_required(VERSION 2.8.10)
project(dmkit C CXX)

# Look under <source>/brpc for a directory whose path ends in "output/include"
# and keep its parent directory; that parent is used as a search prefix for the
# find_path()/find_library() calls further down.
# NOTE(review): presumably this tree is produced by deps.sh - confirm.
execute_process(
    COMMAND bash -c "find ${CMAKE_SOURCE_DIR}/brpc -type d -regex \".*output/include$\" | xargs dirname | tr -d '\n'"
    OUTPUT_VARIABLE OUTPUT_PATH
)

# Put the discovered brpc prefix first so it takes priority, but keep any
# prefixes the user already supplied (e.g. -DCMAKE_PREFIX_PATH=...) instead of
# overwriting them.
set(CMAKE_PREFIX_PATH ${OUTPUT_PATH} ${CMAKE_PREFIX_PATH})
# Threads supplies CMAKE_THREAD_LIBS_INIT; Protobuf supplies
# protobuf_generate_cpp() and the PROTOBUF_* variables used below.
# REQUIRED makes configuration stop here with a clear message when either
# package is missing, instead of failing later at generate or link time.
find_package(Threads REQUIRED)
find_package(Protobuf REQUIRED)

# Generate the C++ source/header pair for the HTTP service definition.
protobuf_generate_cpp(PROTO_SRC PROTO_HEADER proto/http.proto)

# The generated header is written to the binary dir; protobuf's own headers
# may live outside the compiler's default search path, so add them as well.
include_directories(${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIRS})
# brpc: locate the header directory (the one containing brpc/server.h) and the
# library, trying the static archive name before the generic name.
find_path(BRPC_INCLUDE_PATH NAMES brpc/server.h)
find_library(BRPC_LIB NAMES libbrpc.a brpc)

# Both pieces are mandatory; abort configuration if either is missing.
if(NOT BRPC_INCLUDE_PATH OR NOT BRPC_LIB)
    message(FATAL_ERROR "Fail to find brpc")
endif()

include_directories(${BRPC_INCLUDE_PATH})
# gflags: header directory and library are both mandatory.
find_path(GFLAGS_INCLUDE_PATH gflags/gflags.h)
find_library(GFLAGS_LIBRARY NAMES gflags libgflags)
if((NOT GFLAGS_INCLUDE_PATH) OR (NOT GFLAGS_LIBRARY))
    message(FATAL_ERROR "Fail to find gflags")
endif()
include_directories(${GFLAGS_INCLUDE_PATH})

# Detect the C++ namespace used by the installed gflags by taking the first
# "namespace <name> {" line of gflags_declare.h.
execute_process(
    COMMAND bash -c "grep \"namespace [_A-Za-z0-9]\\+ {\" ${GFLAGS_INCLUDE_PATH}/gflags/gflags_declare.h | head -1 | awk '{print $2}' | tr -d '\n'"
    OUTPUT_VARIABLE GFLAGS_NS
)

# When the header spells the namespace through the GFLAGS_NAMESPACE macro,
# resolve the macro to its value. The expansion is quoted so that an empty
# result from the grep above does not turn this into a malformed if().
if("${GFLAGS_NS}" STREQUAL "GFLAGS_NAMESPACE")
    execute_process(
        COMMAND bash -c "grep \"#define GFLAGS_NAMESPACE [_A-Za-z0-9]\\+\" ${GFLAGS_INCLUDE_PATH}/gflags/gflags_declare.h | head -1 | awk '{print $3}' | tr -d '\n'"
        OUTPUT_VARIABLE GFLAGS_NS
    )
endif()

# GFLAGS_NS is later passed to the compiler as -DGFLAGS_NS=<value>; surface an
# empty detection result at configure time rather than leaving it silent.
if("${GFLAGS_NS}" STREQUAL "")
    message(WARNING "Could not detect the gflags namespace from ${GFLAGS_INCLUDE_PATH}/gflags/gflags_declare.h")
endif()
# On macOS, probe for clock_gettime and define NO_CLOCK_GETTIME_IN_MAC when the
# function is not available.
if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
    include(CheckFunctionExists)
    check_function_exists(clock_gettime HAVE_CLOCK_GETTIME)
    if(NOT HAVE_CLOCK_GETTIME)
        set(DEFINE_CLOCK_GETTIME "-DNO_CLOCK_GETTIME_IN_MAC")
    endif()
endif()

# Preprocessor definitions shared by the build, including the detected gflags
# namespace.
set(CMAKE_CPP_FLAGS "${DEFINE_CLOCK_GETTIME} ${CMAKE_CPP_FLAGS} -DGFLAGS_NS=${GFLAGS_NS}")

# Project compile flags. Any flags the user supplied (CXXFLAGS environment
# variable or -DCMAKE_CXX_FLAGS=...) are appended last so they are preserved
# and can override the project defaults instead of being discarded.
set(CMAKE_CXX_FLAGS "${CMAKE_CPP_FLAGS} -DNDEBUG -O2 -D__const__= -pipe -W -Wall -Wno-unused-parameter -fPIC -fno-omit-frame-pointer ${CMAKE_CXX_FLAGS}")

# Request C++11: CMake older than 3.1.3 gets an explicit -std flag for GCC and
# Clang; newer versions use the CMAKE_CXX_STANDARD mechanism.
if(CMAKE_VERSION VERSION_LESS "3.1.3")
    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
    endif()
else()
    set(CMAKE_CXX_STANDARD 11)
    set(CMAKE_CXX_STANDARD_REQUIRED ON)
endif()
# leveldb: locate the header directory (containing leveldb/db.h) and the library.
find_path(LEVELDB_INCLUDE_PATH NAMES leveldb/db.h)
find_library(LEVELDB_LIB NAMES leveldb)

# Abort configuration when either the headers or the library are missing.
if(NOT LEVELDB_INCLUDE_PATH OR NOT LEVELDB_LIB)
    message(FATAL_ERROR "Fail to find leveldb")
endif()

include_directories(${LEVELDB_INCLUDE_PATH})
# On macOS, point FindOpenSSL at a Homebrew-installed OpenSSL, but only when
# the user has not already chosen a root through the OPENSSL_ROOT_DIR variable
# or environment variable.
if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
    if(NOT OPENSSL_ROOT_DIR AND "$ENV{OPENSSL_ROOT_DIR}" STREQUAL "")
        # Homebrew installs under /usr/local on Intel Macs and under
        # /opt/homebrew on Apple Silicon; use the first location that exists.
        foreach(openssl_candidate "/usr/local/opt/openssl" "/opt/homebrew/opt/openssl")
            if(IS_DIRECTORY "${openssl_candidate}")
                set(OPENSSL_ROOT_DIR "${openssl_candidate}")
                break()
            endif()
        endforeach()
    endif()
endif()

# REQUIRED stops configuration with a clear message when OpenSSL is missing;
# OPENSSL_LIBRARIES and OPENSSL_CRYPTO_LIBRARY are consumed when linking.
find_package(OpenSSL REQUIRED)
# libcurl is mandatory: locate the library and abort configuration without it.
find_library(CURL_LIB NAMES curl)
if(NOT CURL_LIB)
    message(FATAL_ERROR "Fail to find curl")
endif()
# Libraries linked into dmkit in addition to brpc, on every platform.
set(DYNAMIC_LIB
    ${CMAKE_THREAD_LIBS_INIT}
    ${GFLAGS_LIBRARY}
    ${PROTOBUF_LIBRARIES}
    ${LEVELDB_LIB}
    ${OPENSSL_LIBRARIES}
    ${OPENSSL_CRYPTO_LIBRARY}
    ${CURL_LIB}
    dl
)

# macOS additionally links pthread, a set of system frameworks, and passes
# "-U <symbol>" to the linker for four symbols.
# NOTE(review): presumably these symbols are optional at link time - confirm.
if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
    list(APPEND DYNAMIC_LIB
        pthread
        "-framework CoreFoundation"
        "-framework CoreGraphics"
        "-framework CoreData"
        "-framework CoreText"
        "-framework Security"
        "-framework Foundation"
        "-Wl,-U,_MallocExtension_ReleaseFreeMemory"
        "-Wl,-U,_ProfilerStart"
        "-Wl,-U,_ProfilerStop"
        "-Wl,-U,_RegisterThriftProtocol"
    )
endif()
# Header search root for the project's own sources.
include_directories(${CMAKE_SOURCE_DIR}/src)

# Collect every .cpp directly under src/ and one directory level below it.
file(GLOB DMKIT_SRC "src/*.cpp" "src/*/*.cpp")

# The dmkit executable: project sources plus the protobuf-generated files.
add_executable(dmkit
    ${DMKIT_SRC}
    ${PROTO_SRC}
    ${PROTO_HEADER}
)
target_link_libraries(dmkit
    ${BRPC_LIB}
    ${DYNAMIC_LIB}
)

# After dmkit is built, copy the source tree's conf directory into the build
# tree.
add_custom_command(
    TARGET dmkit POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/conf ${CMAKE_BINARY_DIR}/conf
)
================================================
FILE: Dockerfile
================================================
FROM ubuntu:18.04

# Install base tooling before copying the sources, so this layer stays cached
# when only files in the repository change.
RUN apt-get update && apt-get install -y --no-install-recommends sudo cmake wget vim curl ca-certificates
RUN update-ca-certificates

# Bring the repository into the image and work from its root.
WORKDIR /unit-dmkit
COPY . /unit-dmkit

# Run the repository's dependency script for Ubuntu.
RUN sh deps.sh ubuntu

# Configure and build in a fresh _build directory.
RUN rm -rf _build && mkdir _build && cd _build && cmake .. && make -j8

# NOTE(review): presumably the port the dmkit HTTP service listens on - confirm against conf/gflags.conf.
EXPOSE 8010
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: NOTICE
================================================
Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# DMKit
DMKit作为UNIT的开源对话管理模块,可以无缝对接UNIT的理解能力,并赋予开发者多状态的复杂对话流程管理能力,还可以低成本对接外部知识库,迅速丰富话术信息量。
## 快速开始
### 编译DMKit
DMKit基于[brpc](https://github.com/brpc/brpc)开发并提供HTTP服务,支持macOS、Ubuntu、CentOS等系统环境,推荐使用Ubuntu 16.04或CentOS 7。在编译DMKit之前,需要先安装依赖并下载编译brpc:
```bash
sh deps.sh [OS]
```
其中[OS]参数指定系统类型用于安装对应系统依赖,支持取值包括ubuntu、mac、centos。如果已手动安装依赖,则传入none。
使用cmake编译DMKit:
```bash
mkdir _build && cd _build && cmake .. && make
```
### 运行示例技能
DMKit提供了示例场景技能,在运行示例技能之前,需要在UNIT平台配置实现技能的理解能力:[示例场景](docs/demo_skills.md)
根据UNIT平台创建的skill id修改编译产出_build目录下的conf/app/products.json文件,在其中配置所创建skill id与对应场景DMKit配置文件。例如,查询流量及续订场景,在UNIT平台创建skill id为12345,则对应的配置文件内容应为:
```JSON
{
"default": {
"12345": {
"score": 1,
"conf_path": "conf/app/demo/cellular_data.json"
}
}
}
```
在_build目录下运行DMKit:
```bash
./dmkit
```
可以通过tools目录下的bot_emulator.py程序模拟与技能进行交互,使用方法为:
```bash
python bot_emulator.py [skill id] [access token]
```
### 更多文档
* [DMKit快速上手](docs/tutorial.md)
* [可视化配置工具](docs/visual_tool.md)
* [常见问题](docs/faq.md)
### 多语言支持
* PHP:[PHP版本官方代码库](https://github.com/baidu/dm-kit-php)
## 如何贡献
* 提交issue可以是新需求也可以是bug,也可以是对某一个问题的讨论。
* 对于issues中的问题欢迎贡献并发起pull request。
## 讨论
* 可以通过提issue发起问题讨论,issue类型选择“问题”即可。
* 欢迎加入UNIT QQ群(584835350)交流讨论。
================================================
FILE: conf/app/bot_tokens.json
================================================
{
"1234": {
"api_key": "",
"secret_key": ""
}
}
================================================
FILE: conf/app/demo/book_hotel.json
================================================
[
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "请问您要预订哪里的酒店"
}
],
"session": {
"context": {},
"state": "001"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_BOOK_HOTEL",
"slots": [
"user_time",
"user_room_type"
],
"state": ""
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "{%location%}附近的酒店有{%hotel_option%},请问你要预订哪一个?"
}
],
"session": {
"context": {},
"state": "003"
}
}
],
"params": [
{
"name": "location",
"type": "slot_val",
"value": "user_location"
},
{
"name": "hotel_option",
"type": "func_val",
"value": "service_http_get:hotel_service,/hotel/search?location={%location%}"
}
],
"trigger": {
"intent": "INTENT_BOOK_HOTEL",
"slots": [
"user_time",
"user_room_type",
"user_location"
],
"state": ""
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "{%location%}附近的酒店有{%hotel_option%},请问你要预订哪一个?"
}
],
"session": {
"context": {},
"state": "003"
}
}
],
"params": [
{
"name": "location",
"type": "slot_val",
"value": "user_location"
},
{
"name": "hotel_option",
"type": "func_val",
"value": "service_http_get:hotel_service,/hotel/search?location={%location%}"
}
],
"trigger": {
"intent": "INTENT_BOOK_HOTEL",
"slots": [
"user_time",
"user_room_type",
"user_location"
],
"state": "001"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "您选择了{%time%}{%hotel%}的{%room_type%},是否确认预订?"
}
],
"session": {
"context": {},
"state": "002"
}
}
],
"params": [
{
"name": "time",
"type": "slot_val",
"value": "user_time"
},
{
"name": "hotel",
"type": "slot_val",
"value": "user_hotel"
},
{
"name": "room_type",
"type": "slot_val",
"value": "user_room_type"
}
],
"trigger": {
"intent": "INTENT_BOOK_HOTEL",
"slots": [
"user_time",
"user_room_type",
"user_hotel"
],
"state": "001"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "您选择了{%time%}{%hotel%}的{%room_type%},是否确认预订?"
}
],
"session": {
"context": {},
"state": "002"
}
}
],
"params": [
{
"name": "time",
"type": "slot_val",
"value": "user_time"
},
{
"name": "hotel",
"type": "slot_val",
"value": "user_hotel"
},
{
"name": "room_type",
"type": "slot_val",
"value": "user_room_type"
}
],
"trigger": {
"intent": "INTENT_BOOK_HOTEL",
"slots": [
"user_time",
"user_room_type",
"user_hotel"
],
"state": ""
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "已为您预订{%time%}{%hotel%}的{%room_type%}"
}
],
"session": {
"context": {},
"state": "004"
}
}
],
"params": [
{
"name": "time",
"type": "slot_val",
"value": "user_time"
},
{
"name": "hotel",
"type": "slot_val",
"value": "user_hotel"
},
{
"name": "room_type",
"type": "slot_val",
"value": "user_room_type"
},
{
"name": "hotel_option",
"type": "func_val",
"value": "service_http_get:hotel_service,/hotel/book?time={%time%}&hotel={%hotel%}&room_type={%room_type%}"
}
],
"trigger": {
"intent": "INTENT_YES",
"slots": [],
"state": "002"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "请告诉我您的新需求"
}
],
"session": {
"context": {},
"state": "005"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_NO",
"slots": [],
"state": "002"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "您选择了{%time%}{%hotel%}的{%room_type%},是否确认预订?"
}
],
"session": {
"context": {},
"state": "002"
}
}
],
"params": [
{
"name": "time",
"type": "slot_val",
"value": "user_time"
},
{
"name": "hotel",
"type": "slot_val",
"value": "user_hotel"
},
{
"name": "room_type",
"type": "slot_val",
"value": "user_room_type"
}
],
"trigger": {
"intent": "INTENT_BOOK_HOTEL",
"slots": [
"user_time",
"user_room_type",
"user_location",
"user_hotel"
],
"state": ""
}
}
]
================================================
FILE: conf/app/demo/book_hotel.xml
================================================
================================================
FILE: conf/app/demo/cellular_data.json
================================================
[
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "您想查询几月份的流量?"
}
],
"session": {
"context": {},
"state": "001"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_CHECK_DATA_USAGE",
"slots": [],
"state": ""
}
},
{
"output": [
{
"assertion": [
{
"type": "ge",
"value": "{%left%},1"
}
],
"result": [
{
"type": "tts",
"value": "您的省内流量2GB,已用流量为{%usage%}GB;剩余流量为{%left%}G,全国流量包1GB,已用流量为1GB,剩余流量为0GB"
}
],
"session": {
"context": {},
"state": "002"
}
},
{
"assertion": [
{
"type": "gt",
"value": "1,{%left%}"
}
],
"result": [
{
"type": "tts",
"value": "您的省内流量2GB,已用流量为{%usage%}GB,剩余流量为{%left%}G;全国流量包1GB,已用流量为1GB,剩余流量为0GB。发现您的流量余额已经不足了,您是否需要续订流量包呢?"
}
],
"session": {
"context": {},
"state": "003"
}
}
],
"params": [
{
"name": "time",
"type": "slot_val",
"value": "user_time"
},
{
"name": "usage",
"type": "func_val",
"value": "demo_get_cellular_data_usage:{%time%}"
},
{
"name": "left",
"type": "func_val",
"value": "demo_get_cellular_data_left:{%time%}"
}
],
"trigger": {
"intent": "INTENT_CHECK_DATA_USAGE",
"slots": [
"user_time"
],
"state": ""
}
},
{
"output": [
{
"assertion": [
{
"type": "ge",
"value": "{%left%},1"
}
],
"result": [
{
"type": "tts",
"value": "您的省内流量2GB,已用流量为{%usage%}GB;剩余流量为{%left%}G,全国流量包1GB,已用流量为1GB,剩余流量为0GB"
}
],
"session": {
"context": {},
"state": "002"
}
},
{
"assertion": [
{
"type": "gt",
"value": "1,{%left%}"
}
],
"result": [
{
"type": "tts",
"value": "您的省内流量2GB,已用流量为{%usage%}GB,剩余流量为{%left%}G;全国流量包1GB,已用流量为1GB,剩余流量为0GB。发现您的流量余额已经不足了,您是否需要续订流量包呢?"
}
],
"session": {
"context": {},
"state": "003"
}
}
],
"params": [
{
"name": "time",
"type": "slot_val",
"value": "user_time"
},
{
"name": "usage",
"type": "func_val",
"value": "demo_get_cellular_data_usage:{%time%}"
},
{
"name": "left",
"type": "func_val",
"value": "demo_get_cellular_data_left:{%time%}"
}
],
"trigger": {
"intent": "INTENT_CHECK_DATA_USAGE",
"slots": [
"user_time"
],
"state": "001"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "王女士您好,我们有省内,全国,夜间三种流量包,您是要续订什么流量包?"
}
],
"session": {
"context": {},
"state": "004"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_YES",
"slots": [],
"state": "003"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "好的王女士,请问还有什么可以帮助您的吗"
}
],
"session": {
"context": {},
"state": "007"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_NO",
"slots": [],
"state": "003"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "好的,{%type%}有如下选择:{%options%}。请问您想续订哪种?"
}
],
"session": {
"context": {},
"state": "005"
}
}
],
"params": [
{
"name": "type",
"type": "slot_val",
"value": "user_package_type"
},
{
"name": "options",
"type": "func_val",
"value": "demo_get_package_options:{%type%}"
}
],
"trigger": {
"intent": "INTENT_BOOK_DATA_PACKAGE",
"slots": [
"user_package_type"
],
"state": "004"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "好的王女士,您是想续订{%name%}的{%type%}吗?"
}
],
"session": {
"context": {},
"state": "006"
}
}
],
"params": [
{
"name": "type",
"type": "slot_val",
"value": "user_package_type"
},
{
"name": "name",
"type": "slot_val",
"value": "user_package_name"
}
],
"trigger": {
"intent": "INTENT_BOOK_DATA_PACKAGE",
"slots": [
"user_package_type",
"user_package_name"
],
"state": "005"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "好的王女士,请问还有什么可以帮助您的吗"
}
],
"session": {
"context": {},
"state": "007"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_YES",
"slots": [],
"state": "006"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "王女士您好,我们有省内,全国,夜间三种流量包,您是要续订什么流量包?"
}
],
"session": {
"context": {},
"state": "004"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_BOOK_DATA_PACKAGE",
"slots": [],
"state": ""
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "王女士您好,我们有省内,全国,夜间三种流量包,您是要续订什么流量包?"
}
],
"session": {
"context": {},
"state": "004"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_BOOK_DATA_PACKAGE",
"slots": [],
"state": "003"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "好的,很高兴为您服务"
}
],
"session": {
"context": {},
"state": "008"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_NO",
"slots": [],
"state": "007"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "{%dmkit_param_last_tts%}"
}
],
"session": {
"context": {},
"state": "{%state%}"
}
}
],
"params": [
{
"name": "state",
"type": "session_state",
"value": ""
}
],
"trigger": {
"intent": "INTENT_REPEAT",
"slots": [],
"state": ""
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "嗯,您说?"
}
],
"session": {
"context": {},
"state": "009"
}
}
],
"params": [
{
"name": "dmkit_param_context_state",
"type": "session_state",
"value": ""
},
{
"name": "dmkit_param_context_tts",
"type": "string",
"value": "{%dmkit_param_last_tts%}"
}
],
"trigger": {
"intent": "INTENT_WAIT",
"slots": [],
"state": ""
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "{%dmkit_param_context_tts%}"
}
],
"session": {
"context": {},
"state": "{%dmkit_param_context_state%}"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_CONTINUE",
"slots": [],
"state": "009"
}
}
]
================================================
FILE: conf/app/demo/cellular_data.xml
================================================
================================================
FILE: conf/app/demo/quota_adjust.json
================================================
[
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": [
"您需要调整临时额度还是固定额度?",
"请问您要调整固定额度还是临时额度呢?"
]
}
],
"session": {
"context": {},
"state": "001"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_ADJUST_QUOTA",
"slots": [],
"state": ""
}
},
{
"output": [
{
"assertion": [
{
"type": "eq",
"value": "{%type%},临时"
}
],
"result": [
{
"type": "tts",
"value": "临时额度只能提升,请问您是否需要提升临时额度?"
}
],
"session": {
"context": {},
"state": "007"
}
},
{
"assertion": [
{
"type": "eq",
"value": "{%type%},固定"
}
],
"result": [
{
"type": "tts",
"value": "您需要提升额度还是降低额度?"
}
],
"session": {
"context": {},
"state": "002"
}
}
],
"params": [
{
"name": "type",
"type": "slot_val",
"value": "user_type"
}
],
"trigger": {
"intent": "INTENT_ADJUST_QUOTA",
"slots": [
"user_type"
],
"state": ""
}
},
{
"output": [
{
"assertion": [
{
"type": "eq",
"value": "{%type%},临时"
}
],
"result": [
{
"type": "tts",
"value": "临时额度只能提升,请问您是否需要提升临时额度?"
}
],
"session": {
"context": {},
"state": "007"
}
},
{
"assertion": [
{
"type": "eq",
"value": "{%type%},固定"
}
],
"result": [
{
"type": "tts",
"value": "您需要提升额度还是降低额度?"
}
],
"session": {
"context": {},
"state": "002"
}
}
],
"params": [
{
"name": "type",
"type": "slot_val",
"value": "user_type"
}
],
"trigger": {
"intent": "INTENT_ADJUST_QUOTA",
"slots": [
"user_type"
],
"state": "001"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "好的王女士,请问您要{%method%}到多少呢?"
}
],
"session": {
"context": {},
"state": "003"
}
}
],
"params": [
{
"name": "method",
"type": "slot_val",
"value": "user_method"
}
],
"trigger": {
"intent": "INTENT_ADJUST_QUOTA",
"slots": [
"user_type",
"user_method"
],
"state": ""
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "好的王女士,请问您要{%method%}到多少呢?"
}
],
"session": {
"context": {},
"state": "003"
}
}
],
"params": [
{
"name": "method",
"type": "slot_val",
"value": "user_method"
}
],
"trigger": {
"intent": "INTENT_ADJUST_QUOTA",
"slots": [
"user_type",
"user_method"
],
"state": "010"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "好的王女士,请问您要{%method%}到多少呢?"
}
],
"session": {
"context": {},
"state": "003"
}
}
],
"params": [
{
"name": "method",
"type": "slot_val",
"value": "user_method"
}
],
"trigger": {
"intent": "INTENT_ADJUST_QUOTA",
"slots": [
"user_type",
"user_method"
],
"state": "002"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "您确认要将银行卡{%type%}{%method%}至{%amount%}吗?"
}
],
"session": {
"context": {},
"state": "004"
}
}
],
"params": [
{
"name": "method",
"type": "slot_val",
"value": "user_method"
},
{
"name": "amount",
"type": "slot_val",
"value": "user_amount"
},
{
"name": "type",
"type": "slot_val",
"value": "user_type"
}
],
"trigger": {
"intent": "INTENT_ADJUST_QUOTA",
"slots": [
"user_type",
"user_method",
"user_amount"
],
"state": "003"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "您确认要将银行卡{%type%}{%method%}至{%amount%}吗?"
}
],
"session": {
"context": {},
"state": "004"
}
}
],
"params": [
{
"name": "method",
"type": "slot_val",
"value": "user_method"
},
{
"name": "amount",
"type": "slot_val",
"value": "user_amount"
},
{
"name": "type",
"type": "slot_val",
"value": "user_type"
}
],
"trigger": {
"intent": "INTENT_ADJUST_QUOTA",
"slots": [
"user_type",
"user_method",
"user_amount"
],
"state": ""
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": [
"很高兴为您服务,还有其他需要帮助的吗?",
"好的,请问还有其他能帮到您的吗?"
]
}
],
"session": {
"context": {},
"state": "005"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_YES",
"slots": [],
"state": "004"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": [
"很高兴为您服务,还有其他需要帮助的吗?",
"好的,请问还有其他能帮到您的吗?"
]
}
],
"session": {
"context": {},
"state": "005"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_YES",
"slots": [],
"state": "011"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "很高兴为您服务,王女士再见!"
}
],
"session": {
"context": {},
"state": "006"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_NO",
"slots": [],
"state": "009"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "很高兴为您服务,王女士再见!"
}
],
"session": {
"context": {},
"state": "006"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_NO",
"slots": [],
"state": "005"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "好的王女士,请问您要提升到多少呢?"
}
],
"session": {
"context": {},
"state": "008"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_YES",
"slots": [],
"state": "007"
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "不好意思,临时额度只能提升,请问还有其他需要帮助的吗?"
}
],
"session": {
"context": {},
"state": "009"
}
}
],
"params": [],
"trigger": {
"intent": "INTENT_NO",
"slots": [],
"state": ""
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "您是要{%method%}固定额度还是临时额度?"
}
],
"session": {
"context": {},
"state": "010"
}
}
],
"params": [
{
"name": "method",
"type": "slot_val",
"value": "user_method"
}
],
"trigger": {
"intent": "INTENT_ADJUST_QUOTA",
"slots": [
"user_method"
],
"state": ""
}
},
{
"output": [
{
"assertion": [],
"result": [
{
"type": "tts",
"value": "您确认要将银行卡临时额度提升至{%amount%}吗?"
}
],
"session": {
"context": {},
"state": "011"
}
}
],
"params": [
{
"name": "amount",
"type": "slot_val",
"value": "user_amount"
}
],
"trigger": {
"intent": "INTENT_ADJUST_QUOTA",
"slots": [
"user_type",
"user_amount"
],
"state": "008"
}
}
]
================================================
FILE: conf/app/demo/quota_adjust.xml
================================================
================================================
FILE: conf/app/products.json
================================================
{
"default": {
"4050": {
"score": 1,
"conf_path": "conf/app/demo/cellular_data.json"
},
"1234": {
"score": 1,
"conf_path": "conf/app/demo/quota_adjust.json"
},
"11669": {
"score": 1,
"conf_path": "conf/app/demo/book_hotel.json"
}
}
}
================================================
FILE: conf/app/remote_services.json
================================================
{
"unit_bot": {
"naming_service_url": "https://aip.baidubce.com",
"load_balancer_name": "",
"protocol": "http",
"client": "brpc",
"timeout_ms": 3000,
"retry": 1,
"headers": {
"Host": "aip.baidubce.com",
"Content-Type": "application/json"
}
},
"token_auth": {
"naming_service_url": "https://aip.baidubce.com",
"load_balancer_name": "",
"protocol": "http",
"client": "brpc",
"timeout_ms": 3000,
"retry": 1,
"headers": {
"Host": "aip.baidubce.com"
}
},
"hotel_service": {
"naming_service_url": "http://127.0.0.1:5000",
"load_balancer_name": "random",
"protocol": "http",
"client": "brpc",
"timeout_ms": 3000,
"retry": 1,
"headers": {
}
}
}
================================================
FILE: conf/gflags.conf
================================================
# TCP Port of this server
--port=8010
# Only allow builtin services at this port
--internal_port=8011
# Connection will be closed if there is no read/write operations during this time
--idle_timeout_s=-1
# Limit of requests processing in parallel
--max_concurrency=0
# Url path of the app
--url_path=/search
# Log to file
--log_to_file=true
================================================
FILE: deps.sh
================================================
#!/usr/bin/env bash
#
# Optionally installs the system packages DMKit depends on, then clones brpc
# into ./brpc (next to this script), checks out a pinned commit and builds it
# in ./brpc/_build.
#
# Usage: deps.sh [ubuntu|mac|centos|none]
#   ubuntu|mac|centos  install packages with the platform package manager first
#   none               skip package installation, only fetch and build brpc
#
# Environment:
#   JOBS  number of parallel make jobs (defaults to 8)

# Work from the directory containing this script. Abort if that fails so the
# `rm -rf _build` further down can never run in an unexpected directory.
cd "$(dirname "$0")" || exit 1

JOBS=${JOBS:-8}
OS=$1
BRPC_COMMIT=2ae7f04ce513c6aee27545df49d5439a98ae3a3f

case "$OS" in
    mac|ubuntu|centos|none)
        ;;
    *)
        echo "Usage: $0 [ubuntu|mac|centos|none]"
        exit 1
        ;;
esac

# Package installation stays best-effort: a failing package manager does not
# stop the script, the brpc build below reports what is actually missing.
if [ "$OS" = mac ]; then
    echo "Installing dependencies for MacOS..."
    brew install openssl git gnu-getopt gflags protobuf leveldb cmake
elif [ "$OS" = ubuntu ]; then
    echo "Installing dependencies for Ubuntu..."
    sudo apt-get install -y git \
        g++ \
        make \
        libssl-dev \
        coreutils \
        libgflags-dev \
        libprotobuf-dev \
        libprotoc-dev \
        protobuf-compiler \
        libleveldb-dev \
        libsnappy-dev \
        libcurl4-openssl-dev \
        libgoogle-perftools-dev
elif [ "$OS" = centos ]; then
    echo "Installing dependencies for CentOS..."
    sudo yum install -y epel-release
    sudo yum install -y git gcc-c++ make openssl-devel libcurl-devel
    sudo yum install -y gflags-devel protobuf-devel protobuf-compiler leveldb-devel gperftools-devel
else
    echo "Skipping dependencies installation..."
fi

if [ ! -e brpc ]; then
    echo "Cloning brpc..."
    git clone https://github.com/brpc/brpc.git || exit 1
fi

# Without this check a missing ./brpc would make the git and rm commands below
# operate on the DMKit checkout itself.
cd brpc || exit 1

# Refreshing master is best-effort (it may fail offline); what matters is that
# the pinned commit can be checked out afterwards.
git checkout master
git pull
git checkout "$BRPC_COMMIT" || exit 1

# build brpc
echo "Building brpc..."
rm -rf _build
mkdir -p _build
cd _build || exit 1
cmake .. || exit 1
make -j"$JOBS" || exit 1
cd ../../
================================================
FILE: docs/demo_book_hotel_pattern.txt
================================================
INTENT_BOOK_HOTEL 0.8 [D:user_hotel]#@##0#@##1 1
INTENT_YES 0.7 [D:kw_yes]#@##0#@##1 1
INTENT_NO 0.7 [D:kw_no]#@##0#@##1 1
INTENT_BOOK_HOTEL 0.7 [D:kw_book]#@##0#@##1#@@##[D:kw_hotel]#@##0#@##1#@@##[D:user_time]#@##0#@##0#@@##[D:user_room_type]#@##0#@##0#@@##[D:user_hotel]#@##0#@##0#@@##[D:user_location]#@##0#@##0 1
================================================
FILE: docs/demo_cellular_data_pattern.txt
================================================
INTENT_CONTINUE 0.8 [D:kw_continue]#@##0#@##1 1
INTENT_WAIT 0.8 [D:kw_wait]#@##0#@##1 1
INTENT_CHECK_DATA_USAGE 0.8 [D:user_time]#@##0#@##1 1
INTENT_REPEAT 0.7 [D:kw_repeat]#@##0#@##1 1
INTENT_BOOK_DATA_PACKAGE 0.9 [D:user_package_type]#@##0#@##1 1
INTENT_BOOK_DATA_PACKAGE 0.9 [D:kw_book]#@##0#@##1#@@##[D:kw_package]#@##0#@##0#@@##[D:user_package_type]#@##0#@##0#@@##[D:user_package_name]#@##0#@##0 1
INTENT_CHECK_DATA_USAGE 0.9 [D:kw_check]#@##0#@##1#@@##[D:user_time]#@##0#@##0#@@##[D:kw_data_usage]#@##0#@##1 1
INTENT_NO 0.9 [D:kw_no]#@##0#@##1 1
INTENT_YES 0.9 [D:kw_yes]#@##0#@##1 1
================================================
FILE: docs/demo_quota_adjust_pattern.txt
================================================
INTENT_NO 0.9 [D:kw_no]#@##0#@##1 1
INTENT_YES 0.9 [D:kw_yes]#@##0#@##1 1
INTENT_ADJUST_QUOTA 0.9 [D:kw_adjust]#@##0#@##0#@@##[D:kw_quota]#@##0#@##0#@@##[D:user_type]#@##0#@##0#@@##[D:user_method]#@##0#@##0#@@##[D:user_amount]#@##0#@##0 1
================================================
FILE: docs/demo_skills.md
================================================
# DMKit示例场景
## 查询流量及续订
该场景实现一个简单的手机流量查询及续订流量包的技能。
### 查询流量及续订 - UNIT平台配置
一共分为以下几个步骤:
* 创建技能
* 配置意图
* 配置词槽
* 添加词槽的词典值
* 添加特征词列表
* 导入对话模板
#### 创建技能
1. 进入[百度理解与交互(unit)平台](http://ai.baidu.com/unit/v2#/sceneliblist)
2. 新建一个技能给查询流量及续订的demo使用。
#### 配置意图
1. 点击进入技能,新建对话意图,以INTENT_CHECK_DATA_USAGE为例。
2. 意图名称填写INTENT_CHECK_DATA_USAGE。
3. 意图别名可以根据自己偏好填写,比如填写查询手机流量。
4. 添加词槽。INTENT_CHECK_DATA_USAGE 所需的词槽为 user_time;
5. 词槽别名可以根据自己偏好填写,比如填写查询时间,查询月份等。
6. 添加完词槽名和词槽别名以后,点击下一步,进入到选择词典环节,user_time选择系统词槽词典中的 **sys_time(时间)** 词槽。
7. 点击下一步,词槽必填选项选择非必填,点击确认。
**其他两个user_package_type和user_package_name的词槽使用自定义词典,将下方词槽对应的词典内容自行粘贴到txt中上传即可**
全部意图包括列表如下:
* INTENT_CHECK_DATA_USAGE,所需词槽为 user_time
* INTENT_BOOK_DATA_PACKAGE,所需词槽为 user_package_type,user_package_name
* INTENT_YES
* INTENT_NO
* INTENT_REPEAT
* INTENT_WAIT
* INTENT_CONTINUE
词槽列表:
* user_time,使用系统时间词槽
* user_package_type
```text
省内
#省内流量
全国
#全国流量
夜间
#夜间流量
```
* user_package_name
```text
10元100M
#10元100兆
#十元一百M
#十元一百兆
20元300M
#20元300兆
#二十元三百M
#二十元三百兆
50元1G
#五十元1G
```
#### 新增特征词
点击**效果优化**中的**对话模板**,选择下方的新建特征词,将名称和词典值依次填入。eg:kw_check即为名称,名称下方为词典值,直接复制粘贴即可。
特征词列表如下:
* kw_check
```text
查一查
查询
```
* kw_data_usage
```text
流量
流量情况
流量使用情况
流量使用
```
* kw_book
```text
续定
续订
预定
预订
```
* kw_package
```text
流量包
流量套餐
```
* kw_yes
```text
是
好
对
想
要
是的
好的
对的
我想
我要
可以
行的
需要
没问题
```
* kw_no
```text
不
不想
不要
不行
别
没有
没了
没
不用
不需要
没有了
```
* kw_repeat
```text
没听清楚
再说一次
```
* kw_wait
```text
稍等
等等
稍等一下
等一下
等一等
```
* kw_continue
```text
你继续
您继续
继续
```
#### 导入对话模板
* 完成以上步骤后,再进行该步骤,不然系统会报错
* 将文件[demo_cellular_data_pattern.txt](demo_cellular_data_pattern.txt)下载导入即可
### 查询流量及续订 - DMKit配置
该场景DMKit配置为conf/app/demo/cellular_data.json文件,该文件由同目录下对应的.xml文件生成,可以将.xml文件在[draw.io](https://www.draw.io/)中导入查看。
## 调整银行卡额度
该场景实现一个简单的银行卡固定额度及临时额度调整的技能。
### 调整银行卡额度 - UNIT平台配置
平台配置参考查询流量的配置,所需的配置内容见下图。
所需意图包括列表:
* INTENT_ADJUST_QUOTA, 所需词槽为user_type, user_method, user_amount
* INTENT_YES
* INTENT_NO
词槽列表:
* user_amount, 复用系统sys_num词槽
* user_type
```text
固定
#固定额度
临时
#临时额度
```
* user_method
```text
提升
#提高
#增加
降低
#减少
#下调
```
特征词列表:
* kw_adjust
```text
调整
调整一下
改变
```
* kw_quota
```text
额度
银行卡额度
信用卡额度
```
* kw_yes
```text
是
好
对
想
要
是的
好的
对的
我想
我要
可以
行的
没问题
```
* kw_no
```text
不
不想
不要
不对
错了
不是
别
没有
没了
没
不用
没有了
```
对话模板:
* 将文件[demo_quota_adjust_pattern.txt](demo_quota_adjust_pattern.txt)下载导入即可
### 调整银行卡额度 - DMKit配置
该场景DMKit配置为conf/app/demo/quota_adjust.json文件,该文件由同目录下对应的.xml文件生成,可以将.xml文件在[draw.io](https://www.draw.io/)中导入查看。
## 预订酒店
该场景实现一个简单的预订酒店的技能。
### 预订酒店 - UNIT平台配置
平台配置参考查询流量的配置,所需的配置内容见下图。
所需意图包括列表:
* INTENT_BOOK_HOTEL, 所需词槽为user_time, user_room_type, user_location, user_hotel
* INTENT_YES
* INTENT_NO
词槽列表:
* user_time, 复用系统sys_time词槽。需要设置为必填词槽,澄清话术配置为『请问您要预订哪一天的酒店?』
* user_room_type。需要设置为必填词槽,澄清话术配置为『请问您要预订哪个房型?』
```text
标间
大床房
单人间
双人间
双床房
标准房
标准间
```
* user_hotel, 复用系统sys_loc_hotel词槽
* user_location, 复用系统sys_loc词槽
特征词列表:
* kw_book
```text
定
预定
订
预订
```
* kw_hotel
```text
旅馆
酒店
```
* kw_yes
```text
是
好
对
想
要
是的
好的
对的
我想
我要
可以
行的
没问题
确认
```
* kw_no
```text
不
不想
不要
不对
不是
不用
不行
不可以
别
没
没有
没了
没有了
错了
```
对话模板:
* 将文件[demo_book_hotel_pattern.txt](demo_book_hotel_pattern.txt)下载导入即可
### 预订酒店 - DMKit配置
该场景DMKit配置为conf/app/demo/book_hotel.json文件,该文件由同目录下对应的.xml文件生成,可以将.xml文件在[draw.io](https://www.draw.io/)中导入查看。
================================================
FILE: docs/faq.md
================================================
# DMKit 常见问题
## 编译brpc失败
参考[BRPC官方文档](https://github.com/apache/brpc/blob/master/docs/cn/getting_started.md),检查是否已安装所需依赖库。建议使用系统Ubuntu 16.04或者CentOS 7。更多BRPC相关问题请在BRPC github库提起issue。
## 返回错误信息 DM policy resolve failed
DMKit通过UNIT云端技能解析出用户query的意图和词槽之后,需要根据对话意图结合当前对话状态,在DMKit配置中选择对应的policy处理对话流程。当用户query解析出的意图结合当前对话状态未能找到可选的policy,或者选中的policy执行流程出错时,DMKit返回错误信息DM policy resolve failed。开发者需要检查:1)当前技能配置是否在products.json文件中进行注册;2)当前query解析结果意图在技能配置中是否配置了policy,详细配置说明参考[DMKit快速上手](tutorial.md);3)检查DMKit日志,查看policy执行流程是否出错。
## 返回错误信息 Failed to call unit bot api
DMKit访问UNIT云端失败。具体原因需要查看DMKit服务日志,常见原因是请求超时。
对于请求超时的情况,先检查DMKit所在服务器网络连接云端(默认地址为 aip.baidubce.com)是否畅通,尝试修改conf/app/remote_services.json文件中unit_bot服务对应超时时间。如果连接没有问题且增大超时时间无效,则尝试切换请求client:DMKit默认使用BRPC client请求UNIT云端,目前发现偶然情况下HTTPS访问云端出现卡死而返回超时错误。DMKit支持切换为curl方式访问云端,将conf/app/remote_services.json配置中client值由brpc修改为curl即可。需要注意使用curl方式时,建议升级openssl版本不低于1.1.0,libcurl版本不低于7.32。
## 返回错误信息 Unsupported action type satisfy
使用DMKit需要将UNIT平台中【技能设置->高级设置】中【对话回应设置】一项设置为『使用DMKit配置』。设置该选项之后,UNIT云端使用DMKit支持的数据协议。如设置为『在UNIT平台上配置』, DMKit无法识别UNIT云端数据协议,将返回错误Unsupported action type satisfy。
## DMKit如何支持FAQ问答对
目前UNIT平台中将【对话回应】设置为【使用DMKit配置】之后,如果对话触发了平台配置的FAQ问答集,平台返回结果不会将答案赋值给response中的say字段,但是会将答案赋值给名为qid的词槽值。因此,结合DMKit配置可以从词槽qid解析出问答回复后进行返回。例如,平台创建问题意图FAQ_HELLO之后,可以在DMKit对应技能的policy配置中添加以下policy,支持FAQ_HELLO问答意图下的所有问答集:
```json
{
"trigger": {
"intent": "FAQ_HELLO",
"slots": [],
"state": ""
},
"params": [
{
"name": "answer_list",
"type": "slot_val",
"value": "qid"
},
{
"name": "faq_answer",
"type": "func_val",
"value": "split_and_choose:{%answer_list%},|,random"
}
],
"output": [
{
"assertion": [],
"session": {
"context": {},
"state": "001"
},
"result": [
{
"type": "tts",
"value": "{%faq_answer%}"
}
]
}
]
}
```
================================================
FILE: docs/tutorial.md
================================================
# DMKit 快速上手
## 简介
在任务型对话系统(Task-Oriented Dialogue System)中,一般包括了以下几个模块:
* Automatic Speech Recognition(ASR),即语音识别模块,将音频转化为文本输入。
* Natural Language Understanding(NLU),即自然语言理解模块,通过分析文本输入,解析得到对话意图与槽位(Intent + Slots)。
* Dialog Manager(DM),即对话管理模块,根据NLU模块分析得到的意图+槽位值,结合当前对话状态,执行对应的动作并返回结果。其中执行的动作可能涉及到对内部或外部知识库的查询。
* Natural Language Generation(NLG),即自然语言生成。目前一般采用模板的形式。
* Text To Speech(TTS),即文字转语音模块,将对话系统的文本输出转化为音频。
DMKit关注其中的对话管理模块(Dialog Manager),解决对话系统中状态管理、对话逻辑处理等问题。在实际应用中,单个技能下对话逻辑一般都是根据NLU结果中意图与槽位值,结合当前对话状态,确定需要进行处理的子流程。子流程或者返回固定话术结果,或者根据NLU中槽位值与对话状态访问内部或外部知识库获取资源数据并生成话术结果返回,在返回结果的同时也对对话状态进行更新。我们将这部分对话处理逻辑进行抽象,提供一个通过配置快速构建对话流程,可复用的对话管理模块,即Reusable Dialog Manager。
## 架构

如上图所示,系统核心是一个对话管理引擎,在对话管理引擎的基础上,每个技能的实现都是通过一个配置文件来对对话逻辑和流程进行描述,这样每个技能仅需要关注自身对话逻辑,不需要重复开发框架代码。一个技能的配置包括了一系列的policy,每个policy包括三部分:触发条件(trigger),参数变量(params),以及输出(output)。
* 触发条件(trigger)包括了NLU解析得到的意图+槽位值,以及当前的对话状态,限定了该policy被触发的条件;
* 参数变量(params)是该policy运行需要的一些数据变量的定义,可以包括NLU解析结果中的槽位值、session中的变量值以及函数运行结果等。这里的函数需要在系统中进行注册,例如发送网络请求获取数据这样的函数,这些通用的函数在各个技能间都能共享,特殊情况下个别技能会需要定制化注册自己的函数;
* 输出结果(output)即为该policy运行返回作为对话系统的结果,可以包括话术tts及指令,同时还可对对话状态进行更新以及写入session变量。这里的结果可以使用已定义的参数变量进行模板填充。
在技能基础配置之上,还衍生出了一系列扩展功能。例如我们对一些仅需要触发条件及输出的技能,我们可以设计更精简的语法,使用更简洁的配置描述对话policy;对于多状态跳转的场景,我们引入了可视化的编辑工具,来描述对话跳转逻辑。精简语法表示及可视化编辑都可以自动转化为对话管理引擎可执行的配置,在系统中运行。
## 使用DMKit搭建对话技能的一般步骤
DMKit依托UNIT提供的自然语言理解能力,在此基础上搭建对话技能的一般步骤为:
* 通过UNIT平台创建技能,配置技能对话流程所需的意图解析能力;
* 编写技能的policy配置,policy配置文件语法见[技能配置](#技能配置)。对于对话状态繁多、跳转复杂的技能,可以借助[可视化配置工具](visual_tool.md)进行可视化编辑并导出技能配置。
* 将UNIT平台所创建技能的id与其对应policy配置文件注册于DMKit全局注册文件,注册文件配置项见[技能注册](#技能注册)。完成之后编译运行DMKit主程序,访问DMKit[服务接口](#服务接口)即可测试对话效果。
## 详细配置说明
本节详细介绍实现技能功能的配置语法。所有技能的配置均位于模块源码conf/app/目录下。
### 技能注册
products.json为全局注册配置文件,默认采用以"default"为key的配置项,该配置项中每个技能以skill id为key,注册添加技能详细配置,配置字段解释如下:
| 字段 |含义 |
|-----------------|-----------------------------|
|conf_path |技能policy配置文件地址 |
|score |技能排序静态分数,可设置为固定值1 |
### 技能配置
单个技能配置文件包括了一系列policy,每个policy字段说明如下:
| 字段 | 类型 |说明 |
|---------------------------------|--------------|-------------------------------|
|trigger |object | 触发节点,如果一个query满足多个policy的触发条件,则优先取state匹配的policy,再根据slot取覆盖个数最多的 |
|+intent |string | 触发所需NLU意图,意图由UNIT云端对话理解获得。此外,DMKit定义了以下预留意图:<br>dmkit_intent_fallback: 当云端返回意图在DMKit配置中未找到匹配policy时,DMKit将尝试使用该意图触发policy |
|+slots |list | 触发所需槽位值列表|
|+state |string | 触发所需状态值,即上一轮对话session中保存的state字段值 |
|params |list | 变量列表 |
|+params[].name |string | 变量名,后定义的变量可以使用已定义的变量进行模板填充,result节点中的值也可以使用变量进行模板填充。变量的使用格式为{%name%}。当name=dmkit_param_context_xxx时可以直接将该变量以xxx为key存入本轮session结果context中 |
|+params[].type |string | 变量类型,可能的类型为slot_val,request_param,session_context,func_val等,详细类型列表及说明可参照[params类型及说明](#params中变量类型列表及其说明) |
|+params[].value |string | 变量定义值 |
|+params[].required |bool | 是否必须,如果必须的变量为空值时,该policy将不会返回结果 |
|output |list | 返回结果节点,可定义多个output,最终输出会按顺序选择第一个满足assertion条件的output |
|+output[].assertion |list | 使用该output的前提条件列表 |
|+output[].assertion[].type |string | 条件类型,详细列表及说明可参照[assertion类型及说明](#result中assertion类型说明) |
|+output[].assertion[].value |string | 条件值 |
|+output[].session |object | 需要保存的session数据,用于更新对话状态及记录上下文 |
|+output[].session.state |string | 更新的对话状态值 |
|+output[].session.context |kvdict | 写入session的变量节点,该节点下的key+value数据会存入session,在下一轮中可以在变量定义中使用 |
|+output[].result |list | 返回结果中result节点,多个result作为数组元素一起返回 |
|+output[].result[].type |string | result类型 |
|+output[].result[].value |string | result值 |
#### params中变量类型列表及其说明:
| type |说明 |
|----------|--------------|
| slot_val | 从qu结果中取对应的slot值,有归一化值优先取归一化值。当对应tag值存在多个slot时,value值支持tag后按分隔符","添加下标i取对应tag的第i个值(索引从0开始) |
| request_param | 取请求参数对应的字段。这里的请求参数对应请求数据request.client_session字段中包含的K-V对,仅支持V类型为string的参数。例如request.client_session字段值为{"param_name1": "param_value1", "param_name2": "param_value2"},定义type为request_param,value为"param_name1"的变量,该变量将赋值为"param_value1" |
| session_context | 上一轮对话session结果中context结构体中对应的字段,例如上一轮output中session结构体保存了变量: {"context": {"param_name": "{%param_name%}"}, "state": ""}, 本轮可定义变量{"name": "param_name", "type": "session_context", "value": "param_name"} |
| func_val | 调用开发者定义的函数。用户定义函数位于src/user_function目录下,并需要在user_function_manager.cpp文件中进行注册。value值为","连接的参数,其中第一个元素为函数名,第二个元素开始为函数参数 |
| qu_intent | NLU结果中的intent值 |
| session_state | 当前对话session中的state值 |
| string | 字符串值,可以使用已定义变量进行模板填充 |
特别的,开发者可以添加注册自定义函数,定义func_val类型的变量调用自定义函数实现功能扩展、定制化对话逻辑。DMKit默认内置提供了包括以下函数:
| 函数名 |函数说明 | 参数 |
|----------------|--------------|----------------------|
| service_http_get | 通过HTTP GET的方式请求知识库、第三方API等服务,服务地址需配置于conf/app/remote_services.json中 |参数1:remote_services.json中配置的服务名<br>参数2:服务请求的路径,例如"/baidu/unit-dmkit" |
| service_http_post | 通过HTTP POST的方式请求知识库、第三方API等服务,服务地址需配置于conf/app/remote_services.json中。注意:如果请求路径包含中文,需要先对中文进行URL编码后再拼接URL |参数1:remote_services.json中配置的服务名<br>参数2:服务请求的路径,例如"/baidu/unit-dmkit"<br>参数3:POST数据内容 |
| json_get_value | 根据提供的路径从json字符串中获取对应的字段值 |参数1:json字符串<br>参数2:所需获取的字段在json字符串中的路径。例如{"data":{"str":"hello", "arr":[{"str": "world"}]}}中路径data.str对应字段值为"hello", 路径data.arr.0.str对应字段值"world"。|
| url_encode | 对输入字符串进行url编码操作 |参数1:进行编码的字符串|
另外,DMKit默认定义提供以下变量:
| 变量名 | 说明 |
|---------------------------------|--------------------------------------------------------------------------|
| dmkit_param_last_tts | 上一轮返回结果result中第一个type为tts的元素value值,如果不存在则为空字符串 |
| dmkit_param_context_xxxxx | 上一轮session结果context中key为xxxxx的值,同时如果用户定义了名为dmkit_param_context_xxxxx的变量,dmkit自动将该变量以xxxxx为key存入本轮session结果context|
| dmkit_param_slot_xxxxx | qu结果中tag为xxxxx的slot值, 如果存在多个相同tag的slot,则取第一个|
#### result中assertion类型说明:
| type |说明 |
|----------|--------------------------------------------------|
| empty | value值为空 |
| not_empty| value值非空 |
| in | value值以","切分,第一个元素在从第二个元素开始的列表中 |
| not_in | value值以","切分,第一个元素不在从第二个元素开始的列表中 |
| eq | value值以","切分,第一个元素等于第二个元素 |
| gt | value值以","切分,第一个数字大于第二个数字 |
| ge | value值以","切分,第一个数字大于等于第二个数字 |
### 精简语法及可视化配置
* 在默认基础配置之上,有能力的开发者可以自行设计使用更简洁的配置描述对话policy并转化为基础配置进行加载。
* 对于多状态跳转的场景,可以引入可视化的编辑工具来描述对话跳转逻辑。这里我们提供了一个使用[mxgraph](https://github.com/jgraph/mxgraph)进行可视化配置的样例,文档参考:[可视化配置工具](visual_tool.md)
## DMKit服务接口
* DMKit服务监听端口及访问路径等参数可通过conf/gflags.conf文件进行配置,默认请求链接为`http://<HOST>:8010/search`,其中`<HOST>`为DMKit服务所在机器IP,请求方式为POST
* 服务接收POST数据协议与[UNIT2.0接口协议](http://ai.baidu.com/docs#/UNIT-v2-API/top)兼容。开发者按照协议组装JSON数据请求DMKit,DMKit按照该协议返回JSON数据,同时DMKit定义返回结果action_list中custom_reply类型为DM_RESULT时,返回内容为DMKit输出的output结果。
================================================
FILE: docs/visual_tool.md
================================================
# 对话流可视化配置
针对状态繁多、跳转复杂的垂类,DMKit支持通过可视化编辑工具进行状态跳转流程的编辑设计,并同步转化为对话基础配置供对话管理引擎加载执行。
## 基于开源可视化工具的配置样例
这里的可视化编辑工具使用开源的[mxgraph](https://github.com/jgraph/mxgraph)可视化库,对话开发者可在可视化工具上进行图编辑,而该可视化库支持从图转化为xml文件,我们再利用转换框架实现对应的编译器将xml文件转化为对话基础配置加载执行。以demo场景【查询流量及续订】为例,步骤为:
* 在[draw.io](https://www.draw.io/)中按照[编辑规则](#编辑规则)进行图编辑
* 将编辑好的图导出为xml文件,放置于conf/app/demo目录下
* 运行language_compiler/run.py程序,该程序调用对应的转换器将conf/app/demo目录下的xml文件转化为json文件
* 将json文件注册于conf/app/products.json文件后,运行DMKit加载执行
### 编辑规则
规定使用以下构图元素进行编辑:
* 单箭头连线,单箭头连线是流程图中最基本的元素之一,用来表示状态的跳转。注意在使用连线的时候,连线的两端需要出现蓝色的 x 标识,以确保这个连线成功连接了两个框。
* 椭圆,用户节点,椭圆中存放的是用户的意图,以及槽位值(可选),内部语言格式为:
```text
INTENT: intent_xxx
SLOT:user_a,user_b
```
该节点表示用户输入query的NLU解析结果,结合指向该节点的BOT节点,构成了DMKit基础配置中一个完整的trigger条件
* 圆角矩形,BOT节点,圆角矩形中存放的是BOT的回复,内部格式为:
```text
PARAM:param_type:param_name1=param_value1
PARAM:param_type:param_name2=param_value2
BOT:XXXXXXX{%param_name1%}XXX{%param_name2%}
```
该节点表示BOT应该执行的回复,同时节点中可以定义参数并对回复进行模板填充。
在这里定义`dmkit_param_context_xxxx`变量时,dmkit自动将该变量以`xxxx`为key存入本轮session结果context。下一轮可以定义type=session_context,value=xxxx的变量来读取,也可以直接使用value={%dmkit_param_context_xxxx%}来获取,具体可参考[params类型及说明](tutorial.md#params中变量类型列表及其说明)
* 菱形,条件节点,在节点中可定义需要进行判断的变量:
```text
PARAM:param_type:param_name1=param_value1
PARAM:param_type:param_name2=param_value2
```
同时对该节点连出的单箭头连线可以添加描述跳转条件,条件可使用在菱形中定义的变量,例如:
```text
ge:{%param1%},1
```
跳转条件描述中,可用&&或||来连接多个条件以表达“和”或“或”,例如:
```text
ge:{%param1%},1 || eq: {%param1%}, 10
```
||和&&不可同时出现在一个条件描述中。
另外规定:
* 一个椭圆仅可以连向一个圆角矩形或者一个菱形
* 一个圆角矩形可以连向多个椭圆
* 一个菱形可以连向多个圆角矩形
详细使用示例参考conf/app/demo目录下demo场景的xml文件,该xml文件可在[draw.io](https://www.draw.io/)中导入查看
================================================
FILE: language_compiler/compiler_xml.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Parser for .xml file
"""
import xml.etree.ElementTree as ET
import re
import codecs
import json
import sys
class Node(object):
    """
    One shape (vertex) of the dialog graph read from the .xml file.
    Instances are created and linked together by XmlParser.
    """
    def __init__(self, node_id, node_type, node_status, node_nlg, state, params):
        # identity and kind of the node
        self.node_id, self.node_type = node_id, node_type
        # intent for user node, empty for server node
        self.node_status = node_status
        # slots for customer node, nlg for server node
        self.node_nlg = node_nlg
        self.state = state
        # parameters for server node
        self.params = params
        # incoming / outgoing links, each entry is (node_type, node_id, arrow_value)
        self.from_nodes = list()
        self.to_nodes = list()
class Arrow(object):
    """
    One edge of the dialog graph read from the .xml file.
    Instances are created by XmlParser and used to link Node objects.
    """
    def __init__(self, arrow_source, arrow_target):
        # both ends are stored as cell ids
        self.arrow_source, self.arrow_target = arrow_source, arrow_target
        # label of the edge, empty until a text is attached
        self.arrow_text = ""
class XmlParser(object):
"""
this class is used to represent the graph in .xml file as a class object,
all necessary information about the graph will be extracted from the .xml file.
"""
def __init__(self, xml_data):
    """
    Parse the xml text and immediately build the policy list from it.

    xml_data: iterable of strings; the pieces are concatenated and parsed
    as a single xml document.
    """
    # build the element tree from the joined input
    xml_text = "".join(xml_data)
    self.root = ET.fromstring(xml_text)

    # containers filled in by the passes below
    self.__nodes = {}
    self.__arrows = {}
    self.__nlgs = []
    self.__user_nodes = []
    self.__policies = []  # the output json

    # pass 1: turn every cell into a node or an arrow
    self.__parse_cells()
    # pass 2: link the nodes using the arrows
    self.__connect_nodes()
    # pass 3: derive the policies from the linked graph
    self.__extract_policy()
# Remove unwanted markup fragments from `value`.
# Every occurrence of each substring in `noise` is replaced by `stri`, and the
# result is returned with leading/trailing whitespace stripped.
# If `noise` is None or empty, the built-in default list below is used.
# NOTE(review): in this copy of the file several default entries are empty
# strings or bare whitespace and one literal spans two lines; they look like
# HTML tags/entities that were lost in extraction - confirm against the
# repository before relying on them.
def __clean_noise(self, value, stri='', noise=None):
if not noise:
noise = [' ', ' ', 'nbsp', '
', '', '
', '', '', '']
for n in noise:
value = value.replace(n, stri)
# Only runs when the replacement string is empty.
# NOTE(review): as written this maps '&' to itself; presumably an HTML
# entity was meant on the left-hand side - verify.
if not stri:
value = value.replace('&', '&')
return value.strip()
# Thin wrapper around __clean_noise that uses an empty replacement string and
# its own three-entry noise list. The `stri` parameter is accepted but not
# forwarded.
# NOTE(review): the three noise entries are empty strings in this copy;
# judging by the method name they were presumably span tags - confirm.
def __clean_span(self, value, stri=''):
return self.__clean_noise(value, '', ['', '', ''])
def __parse_cells(self):
todo_cells = []
state = 0
for cell in self.root[0]:
style = cell.get('style')
cell_id = cell.get('id')
# bot node
if style and style.startswith('rounded=1'):
value = cell.get('value') + '<'
value = self.__clean_span(value)
re_nlg = re.compile(r'BOT.*?<')
re_params = re.compile(r'PARAM.*?<')
re_state = re.compile(r'STATE.*?<')
nlgs = []
if re_nlg.findall(value):
for n in re_nlg.findall(value):
n = self.__clean_noise(n[4:-1])
nlgs.append(n)
params = []
if re_params.findall(value):
for p in re_params.findall(value):
p = self.__clean_noise(p[6:-1])
params.append(p)
node_state = None
if re_state.findall(value):
for s in re_state.findall(value):
node_state = self.__clean_noise(s[6:-1])
break
if 'BOT' not in value:
raise Exception('wrong shape is used in cell with text: ', nlg)
if value:
if not node_state:
state += 1
node_state = (3 - len(str(state))) * '0' + str(state)
node = Node(cell_id, 'server', '', nlgs, node_state, params)
self.__nodes[cell_id] = node
continue
# customer node
if style and style.startswith('ellipse'):
value = cell.get('value') + '<'
value = self.__clean_span(value)
re_status = re.compile(r'INTENT.*?<')
status = re_status.findall(value)[0]
status = self.__clean_noise(status[status.find(':') + 1:-1])
re_nlg = re.compile(r'SLOT.*?<')
if re_nlg.findall(value):
nlg = self.__clean_noise(re_nlg.findall(value)[0][5:-1])
else:
nlg = ""
if 'INTENT' not in value:
raise Exception('wrong shape is used in cell with text: ', nlg)
if value:
node = Node(cell_id, 'customer', status, nlg, '', [])
self.__nodes[cell_id] = node
self.__user_nodes.append(node)
continue
# arrow node
if cell.get('edge'):
source = cell.get('source')
target = cell.get('target')
text = cell.get('value')
if not target:
raise Exception('some arrow is not pointing to anything')
else:
arrow = Arrow(source, target)
if text:
arrow.arrow_text = self.__clean_noise(text)
self.__arrows[cell_id] = arrow
continue
# judge node
if style and style.startswith('rhombus'):
value = cell.get('value') + '<'
value = self.__clean_span(value)
re_params = re.compile(r'PARAM.*?<')
params = []
if re_params.findall(value):
for p in re_params.findall(value):
p = self.__clean_noise(p[6:-1])
params.append(p)
if 'PARAM' not in value:
raise Exception('wrong shape is used in cell with text: ', value)
if value:
node = Node(cell_id, 'judge', '', '', '', params)
self.__nodes[cell_id] = node
continue
todo_cells.append(cell)
# false initial
self.__false_initial = Node(-1, "", "", "", "", "")
self.__nodes[-1] = self.__false_initial
for cell in todo_cells:
style = cell.get('style')
parent_id = cell.get('parent')
value = cell.get('value')
if style and style.startswith('text'):
if parent_id in self.__arrows:
self.__arrows[parent_id].arrow_text = self.__clean_noise(value)
else:
raise Exception("there is a bad cell with text", value)
def __connect_nodes(self):
for (arrow_id, arrow) in self.__arrows.items():
if not arrow.arrow_source:
source_node = self.__false_initial
else:
source_node = self.__nodes[arrow.arrow_source]
target_node = self.__nodes[arrow.arrow_target]
# update source node and target node
source_node.to_nodes.append((target_node.node_type, target_node.node_id,
arrow.arrow_text))
self.__nodes[arrow.arrow_source] = source_node
target_node.from_nodes.append((source_node.node_type,
source_node.node_id, arrow.arrow_text))
self.__nodes[arrow.arrow_target] = target_node
def __extract_policy(self):
for node in self.__user_nodes:
intent = node.node_status
slots = node.node_nlg.replace(' ', '').split(',')
if slots[0] == "":
slots = []
params = []
output = []
if len(node.to_nodes) > 0:
dir_to_node_info = node.to_nodes[0]
dir_to_node = self.__nodes[dir_to_node_info[1]]
# points to one branch
if dir_to_node.node_type == "server":
for p in dir_to_node.params:
p = p.replace(" ", "")
c1 = p.find(":")
c2 = p.find("=")
pname = p[c1 + 1:c2]
ptype = p[0:c1]
pvalue = p[c2 + 1:]
param = {"name":pname, "type":ptype, "value":pvalue}
params.append(param)
next = {}
next["assertion"] = []
next["session"] = {"state": dir_to_node.state, "context": {}}
results = []
for n in dir_to_node.node_nlg:
result = {}
result["type"] = "tts"
alt = n.replace(" ", "").split('|')
if len(alt) == 1:
result["value"] = alt[0]
else:
result["value"] = alt
results.append(result)
next["result"] = results
output.append(next)
# points to multiple branch
elif dir_to_node.node_type == "judge":
# add params in judge nodes
for p in dir_to_node.params:
p = p.replace(" ", "")
c1 = p.find(":")
c2 = p.find("=")
pname = p[c1 + 1:c2]
ptype = p[0:c1]
pvalue = p[c2 + 1:]
param = {"name":pname, "type":ptype, "value":pvalue}
params.append(param)
# extract policies from bot nodes pointed to by the judge node
for to_node_info in dir_to_node.to_nodes:
to_node = self.__nodes[to_node_info[1]]
nexts = []
condition = to_node_info[2].replace(" ", "")
# A condition can contain either '&&'s or '||'s, but not both
if '&&' in condition:
conditions = condition.split('&&')
assertions = []
for cond in conditions:
cut = cond.find(":")
assertion = {}
assertion["type"] = cond[0:cut]
assertion["value"] = cond[cut + 1:]
assertions.append(assertion)
next = {}
next["assertion"] = assertions
next["session"] = {"state": to_node.state, "context": {}}
nexts.append(next)
elif '||' in condition:
conditions = condition.split('||')
for cond in conditions:
next = {}
assertion = {}
cut = cond.find(":")
assertion["type"] = cond[0:cut]
assertion["value"] = cond[cut + 1:]
next["assertion"] = [assertion]
next["session"] = {"state": to_node.state, "context": {}}
nexts.append(next)
else:
assertion = {}
cut = condition.find(":")
assertion["type"] = condition[0:cut]
assertion["value"] = condition[cut + 1:]
next = {}
next["assertion"] = [assertion]
next["session"] = {"state": to_node.state, "context": {}}
nexts.append(next)
results = []
for n in to_node.node_nlg:
result = {}
result["type"] = "tts"
alt = n.replace(" ", "").split('|')
if len(alt) == 1:
result["value"] = alt[0]
else:
result["value"] = alt
results.append(result)
for nxt in nexts:
nxt["result"] = results
output.append(nxt)
for p in to_node.params:
p = p.replace(" ", "")
c1 = p.find(":")
c2 = p.find("=")
pname = p[c1 + 1:c2]
ptype = p[0:c1]
pvalue = p[c2 + 1:]
param = {"name":pname, "type":ptype, "value":pvalue}
# merge
if param not in params:
params.append(param)
for from_node_info in node.from_nodes:
if from_node_info:
policy = {}
from_node = self.__nodes[from_node_info[1]]
trigger = {}
trigger["intent"] = intent
trigger["slots"] = slots
trigger["state"] = from_node.state
policy["trigger"] = trigger
policy["params"] = params
policy["output"] = output
self.__policies.append(policy)
def write_json(self):
"""
this method is used to return the parsed json for the .xml input
"""
return [json.dumps(self.__policies, ensure_ascii=False,\
indent=2, sort_keys=True).encode('utf8')]
def run(data):
    """
    Entry point of this compiler module: parse the xml content in `data`
    with XmlParser and return the json it produces.
    """
    return XmlParser(data).write_json()
================================================
FILE: language_compiler/run.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Compiler Executor.
"""
import codecs
import ConfigParser
import importlib
import os
import sys
def main():
    """
    Compiler Executor.

    Reads settings.cfg next to this script, walks the configured data path
    and, for every file whose extension is one of the configured compile
    types, calls run() of the module 'compiler_<extension>' on the file's
    lines and writes the returned lines to '<name>.json' in the same
    directory.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # make the compiler_* modules next to this script importable
    sys.path.append(script_dir)

    settings = ConfigParser.ConfigParser()
    settings.read(script_dir + "/settings.cfg")
    compile_types = settings.get('compiler', 'compile_types').replace(' ', '').split(',')
    data_path = script_dir + '/' + settings.get('data', 'data_path')

    for dirpath, _dirnames, filenames in os.walk(data_path):
        for filename in filenames:
            fname, extension = os.path.splitext(filename)
            extension = extension[1:]  # drop the leading dot
            if extension in compile_types:
                with open(os.path.join(dirpath, filename), 'r') as source:
                    input_lines = source.readlines()

                compiler_module = importlib.import_module('compiler_' + extension)
                output_lines = compiler_module.run(input_lines)

                with open(os.path.join(dirpath, fname + ".json"), 'w') as target:
                    for line in output_lines:
                        target.write("%s\n" % line)


if __name__ == "__main__":
    main()
================================================
FILE: language_compiler/settings.cfg
================================================
[data]
data_path=../conf/app
[compiler]
compile_types=xml
exclude_file=
================================================
FILE: proto/http.proto
================================================
syntax = "proto2";

package dmkit;

// Have protoc emit the generic C++ service classes for HttpService.
option cc_generic_services = true;

// Both messages are intentionally empty: they carry no protobuf fields.
// NOTE(review): presumably the actual HTTP payload is handled outside of
// protobuf by the server code - confirm in src/server.cpp.
message HttpRequest {}

message HttpResponse {}

// Single-method service exposed by the DMKit server.
service HttpService {
    rpc run(HttpRequest) returns (HttpResponse);
}
================================================
FILE: src/app_container.cpp
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "app_container.h"
#include
#include "app_log.h"
#include "brpc.h"
#include "dialog_manager.h"
namespace dmkit {
ThreadLocalDataFactory::ThreadLocalDataFactory(ApplicationBase* application) {
this->_application = application;
}
ThreadLocalDataFactory::~ThreadLocalDataFactory() {
this->_application = nullptr;
}
void* ThreadLocalDataFactory::CreateData() const {
return this->_application->create_thread_data();
}
void ThreadLocalDataFactory::DestroyData(void* d) const {
this->_application->destroy_thread_data(d);
}
// Starts with no application loaded and no data factory.
AppContainer::AppContainer()
    : _application(nullptr),
      _data_factory(nullptr) {
}

// Releases the application first, then the data factory that referenced it.
AppContainer::~AppContainer() {
    delete _application;
    _application = nullptr;

    delete _data_factory;
    _data_factory = nullptr;
}
int AppContainer::load_application() {
if (nullptr != this->_application) {
APP_LOG(ERROR) << "an application has already be loaded";
return -1;
}
// The real application is created here
this->_application = new DialogManager();
if (nullptr == this->_application || 0 != this->_application->init()) {
APP_LOG(ERROR) << "failed to init application!!!";
return -1;
}
this->_data_factory = new ThreadLocalDataFactory(this->_application);
return 0;
}
// Returns the data factory created by load_application(),
// or nullptr (after logging an error) when it does not exist yet.
ThreadLocalDataFactory* AppContainer::get_thread_local_data_factory() {
    if (_data_factory != nullptr) {
        return _data_factory;
    }
    APP_LOG(ERROR) << "Data factory has not been initialized!!!";
    return nullptr;
}
// Runs the loaded application for one request and logs its notice log together
// with the elapsed time ("tm", in milliseconds).
// Returns -1 when no application is loaded or no thread-local data is attached
// to the current thread; otherwise returns the result of the application's run().
int AppContainer::run(BRPC_NAMESPACE::Controller* cntl) {
    if (nullptr == this->_application) {
        APP_LOG(ERROR) << "No application is loaded for processing!!!";
        return -1;
    }
    auto time_start = std::chrono::steady_clock::now();
    // Need to reset thread data status before running the application
    ThreadDataBase* tls = static_cast<ThreadDataBase*>(BRPC_NAMESPACE::thread_local_data());
    if (nullptr == tls) {
        // Without thread data there is nowhere to keep the log id and notice log.
        APP_LOG(ERROR) << "Thread local data is not available!!!";
        return -1;
    }
    tls->reset();
    APP_LOG(TRACE) << "Running application";
    int result = this->_application->run(cntl);
    auto time_end = std::chrono::steady_clock::now();
    // Elapsed time as fractional seconds, converted to milliseconds below.
    std::chrono::duration<double> diff =
        std::chrono::duration_cast<std::chrono::duration<double>>(time_end - time_start);
    double total_cost = diff.count() * 1000;
    APP_LOG(TRACE) << "Application run cost(ms): " << total_cost;
    tls->add_notice_log("tm", std::to_string(total_cost));
    APP_LOG(NOTICE) << tls->get_notice_log();
    return result;
}
} // namespace dmkit
================================================
FILE: src/app_container.h
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DMKIT_APP_CONTAINER_H
#define DMKIT_APP_CONTAINER_H
#include "application_base.h"
#include "brpc.h"
namespace dmkit {
class ThreadLocalDataFactory : public BRPC_NAMESPACE::DataFactory {
public:
ThreadLocalDataFactory(ApplicationBase* application);
~ThreadLocalDataFactory();
void* CreateData() const;
void DestroyData(void* d) const;
private:
ApplicationBase* _application;
};
// Container class which manages the application instance,
// as well as the thread data factory instance bound to it.
// Both instances are created by load_application() and deleted by the destructor.
class AppContainer {
public:
AppContainer();
~AppContainer();
// Creates and initializes the application and its data factory.
// Returns 0 on success and -1 on failure.
int load_application();
// Returns the data factory, or nullptr when it has not been created yet.
ThreadLocalDataFactory* get_thread_local_data_factory();
// Runs the application for the request carried by cntl.
// Returns -1 when no application is loaded, otherwise the application's result.
int run(BRPC_NAMESPACE::Controller* cntl);
private:
// The application instance is shared for all rpc threads
ApplicationBase* _application;
// Factory constructed with _application in load_application()
ThreadLocalDataFactory* _data_factory;
};
} // namespace dmkit
#endif //DMKIT_APP_CONTAINER_H
================================================
FILE: src/app_log.h
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DMKIT_APP_LOG_H
#define DMKIT_APP_LOG_H
#include "brpc.h"
#include "butil.h"
#include "thread_data_base.h"
namespace dmkit {
// Wrapper for application logging which prefixes every message with the log id
// (trace id) of the request handled on the current thread.
// When no thread-local data is attached to the thread, the log id is left empty.
#define APP_LOG(severity) \
    LOG(severity) << "logid=" << (BRPC_NAMESPACE::thread_local_data() == nullptr ? "" : \
        (static_cast<ThreadDataBase*>(BRPC_NAMESPACE::thread_local_data()))->get_log_id()) \
        << " "
} // namespace dmkit
#endif //DMKIT_APP_LOG_H
================================================
FILE: src/application_base.h
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DMKIT_APPLICATION_BASE_H
#define DMKIT_APPLICATION_BASE_H
#include "brpc.h"
#include "thread_data_base.h"
namespace dmkit {
// Base class for applications
// Notice that the same application instance will be used for all request thread,
// This call should be thread safe.
class ApplicationBase {
public:
ApplicationBase() {};
virtual ~ApplicationBase() {};
// Interface for application to do global initialization.
// This method will be invoke only once when server starts.
virtual int init() = 0;
// Interface for application to handle requests, it should be thread safe.
virtual int run(BRPC_NAMESPACE::Controller* cntl) = 0;
// Interface for application to register customized thread data.
virtual void* create_thread_data() const {
return new ThreadDataBase();
}
// Interface to destroy thread data, the data instance was created by create_thread_data.
virtual void destroy_thread_data(void* d) const {
delete static_cast(d);
}
// Set log id for current request,
// application should set log id as early as possible when processing request
virtual void set_log_id(const std::string& log_id) {
ThreadDataBase* tls = static_cast(BRPC_NAMESPACE::thread_local_data());
tls->set_log_id(log_id);
}
// Add a key/value notice log which will be logged when finish processing request
virtual void add_notice_log(const std::string& key, const std::string& value) {
ThreadDataBase* tls = static_cast(BRPC_NAMESPACE::thread_local_data());
tls->add_notice_log(key, value);
}
};
} // namespace dmkit
#endif //DMKIT_APPLICATION_BASE_H
================================================
FILE: src/brpc.h
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DMKIT_BRPC_H
#define DMKIT_BRPC_H
#ifndef BRPC_INCLUDE_PREFIX
#define BRPC_INCLUDE_PREFIX
#include BRPC_INCLUDE_PREFIX/channel.h>
#include BRPC_INCLUDE_PREFIX/controller.h>
#include BRPC_INCLUDE_PREFIX/restful.h>
#include BRPC_INCLUDE_PREFIX/server.h>
#endif //DMKIT_BRPC_H
================================================
FILE: src/butil.h
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DMKIT_BUTIL_H
#define DMKIT_BUTIL_H
#ifndef BUTIL_INCLUDE_PREFIX
#define BUTIL_INCLUDE_PREFIX
#ifdef BUTIL_ENABLE_COMLOG_SINK
#include BUTIL_INCLUDE_PREFIX/comlog_sink.h>
#endif
#include BUTIL_INCLUDE_PREFIX/containers/flat_map.h>
#include BUTIL_INCLUDE_PREFIX/logging.h>
#endif //DMKIT_BUTIL_H
================================================
FILE: src/dialog_manager.cpp
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dialog_manager.h"
#include
#include
#include "app_log.h"
#include "butil.h"
#include "rapidjson.h"
#include "request_context.h"
#include "utils.h"
namespace dmkit {
// Allocates the three managers owned by the dialog manager.
// They are configured later by init().
DialogManager::DialogManager()
    : _remote_service_manager(new RemoteServiceManager()),
      _policy_manager(new PolicyManager()),
      _token_manager(new TokenManager()) {
}

// Releases the managers created by the constructor.
DialogManager::~DialogManager() {
    delete _policy_manager;
    _policy_manager = nullptr;

    delete _remote_service_manager;
    _remote_service_manager = nullptr;

    delete _token_manager;
    _token_manager = nullptr;
}
// Initializes the remote service, policy and token managers from conf/app.
// Returns 0 when all of them are initialized, -1 as soon as one of them fails.
int DialogManager::init() {
    if (0 != this->_remote_service_manager->init("conf/app", "remote_services.json")) {
        APP_LOG(ERROR) << "Failed to init _remote_service_manager";
        return -1;
    }
    APP_LOG(TRACE) << "_remote_service_manager init done";

    if (0 != this->_policy_manager->init("conf/app", "products.json")) {
        APP_LOG(ERROR) << "Failed to init _policy_manager";
        return -1;
    }
    // Logged right after the policy manager succeeded, so a token manager
    // failure is not preceded by a misleading success trace.
    APP_LOG(TRACE) << "_policy_manager init done";

    if (0 != this->_token_manager->init("conf/app", "bot_tokens.json")) {
        APP_LOG(ERROR) << "Failed to init _token_manager";
        return -1;
    }
    APP_LOG(TRACE) << "_token_manager init done";
    return 0;
}
int DialogManager::run(BRPC_NAMESPACE::Controller* cntl) {
std::string request_json = cntl->request_attachment().to_string();
APP_LOG(TRACE) << "received request: " << request_json;
this->add_notice_log("req", request_json);
rapidjson::Document request_doc;
// In the case we cannot parse the request json, it is not a valid request.
if (request_doc.Parse(request_json.c_str()).HasParseError() || !request_doc.IsObject()) {
APP_LOG(WARNING) << "Failed to parse request data to json";
cntl->http_response().set_status_code(400);
return 0;
}
// Need to set log_id as soon as we can get it to avoid missing log_id in logs.
if (!request_doc.HasMember("log_id") || !request_doc["log_id"].IsString()) {
APP_LOG(WARNING) << "Missing log_id";
this->send_json_response(cntl, this->get_error_response(-1, "Missing log_id"));
return 0;
}
std::string log_id = request_doc["log_id"].GetString();
std::string dmkit_log_id_prefix = "dmkit_";
log_id = dmkit_log_id_prefix + log_id;
this->set_log_id(log_id);
request_doc["log_id"].SetString(log_id.c_str(), log_id.length(), request_doc.GetAllocator());
// Parsing bot_id from request json
if (!request_doc.HasMember("bot_id") || !request_doc["bot_id"].IsString()) {
APP_LOG(WARNING) << "Missing bot_id";
this->send_json_response(cntl, this->get_error_response(-1, "Missing bot_id"));
return 0;
}
std::string bot_id = request_doc["bot_id"].GetString();
if (!request_doc.HasMember("request") || !request_doc["request"].HasMember("query")
|| !request_doc["request"]["query"].IsString()) {
APP_LOG(WARNING) << "Missing query";
this->send_json_response(cntl, this->get_error_response(-1, "Missing query"));
return 0;
}
std::string query = request_doc["request"]["query"].GetString();
std::string rewrite_query;
if (request_doc["request"].HasMember("rewrite_query")
&& request_doc["request"]["rewrite_query"].IsString()) {
rewrite_query = request_doc["request"]["rewrite_query"].GetString();
request_doc["request"].RemoveMember("rewrite_query");
}
// Get dmkit session from request. We saved it in the bot session in latest response.
std::string dm_session;
if (request_doc.HasMember("bot_session")) {
std::string request_bot_session = request_doc["bot_session"].GetString();
rapidjson::Document request_bot_session_doc;
if (request_bot_session.empty()
|| request_bot_session_doc.Parse(request_bot_session.c_str()).HasParseError()
|| !request_bot_session_doc.IsObject()
|| !request_bot_session_doc.HasMember("bot_id")
|| !request_bot_session_doc["bot_id"].IsString()
|| request_bot_session_doc["bot_id"].GetString() != bot_id
|| !request_bot_session_doc.HasMember("session_id")
|| !request_bot_session_doc.HasMember("dialog_state")
|| !request_bot_session_doc["dialog_state"].HasMember("contexts")
|| !request_bot_session_doc["dialog_state"]["contexts"].HasMember("dmkit")
|| !request_bot_session_doc["dialog_state"]["contexts"]["dmkit"].HasMember("session")) {
// Not a valid session from DMKit
request_doc["bot_session"].SetString("", 0, request_doc.GetAllocator());
} else {
dm_session =
request_bot_session_doc["dialog_state"]["contexts"]["dmkit"]["session"].GetString();
}
}
APP_LOG(TRACE) << "dm session: " << dm_session;
PolicyOutputSession session = PolicyOutputSession::from_json_str(dm_session);
// Get access_token from request uri.
std::string access_token;
const std::string* access_token_ptr = cntl->http_request().uri().GetQuery("access_token");
if (access_token_ptr != nullptr) {
access_token = *access_token_ptr;
}
if (access_token.empty() && this->_token_manager->get_access_token(
bot_id, this->_remote_service_manager, access_token) != 0) {
APP_LOG(ERROR) << "Failed to get access token";
this->send_json_response(cntl, this->get_error_response(-1, "Failed to get access token"));
return 0;
}
std::string query_response;
bool is_dmkit_response = false;
this->process_request(request_doc, dm_session, access_token, query_response, is_dmkit_response);
if (is_dmkit_response || rewrite_query.empty()) {
this->send_json_response(cntl, query_response);
return 0;
}
std::string rewrite_query_response;
request_doc["request"]["query"].SetString(rewrite_query.c_str(),
rewrite_query.length(), request_doc.GetAllocator());
this->process_request(request_doc, dm_session, access_token,
rewrite_query_response, is_dmkit_response);
if (is_dmkit_response) {
this->send_json_response(cntl, rewrite_query_response);
return 0;
}
this->send_json_response(cntl, query_response);
return 0;
}
// Sends the request to the unit bot and, when the bot understood the intent,
// resolves a DM policy and embeds its output into the bot response.
// json_response receives the response to send; is_dmkit_response is set to true
// only when a policy output was produced.
// Returns -1 when the unit bot result is not a JSON object, 0 otherwise.
int DialogManager::process_request(const rapidjson::Document& request_doc,
                                   const std::string& dm_session,
                                   const std::string& access_token,
                                   std::string& json_response,
                                   bool& is_dmkit_response) {
    std::string bot_id = request_doc["bot_id"].GetString();
    std::string log_id = request_doc["log_id"].GetString();
    std::string query = request_doc["request"]["query"].GetString();
    is_dmkit_response = false;
    std::string request_json = utils::json_to_string(request_doc);

    // Call unit bot api with the request json as dmkit use the same data contract.
    std::string unit_bot_result;
    if (this->call_unit_bot(access_token, request_json, unit_bot_result) != 0) {
        APP_LOG(ERROR) << "Failed to call unit bot api";
        json_response = get_error_response(-1, "Failed to call unit bot api");
        return 0;
    }
    APP_LOG(TRACE) << "unit bot result: " << unit_bot_result;

    // Parse unit bot response.
    // In the case something wrong with unit bot response, informs users.
    rapidjson::Document unit_response_doc;
    if (unit_response_doc.Parse(unit_bot_result.c_str()).HasParseError()
            || !unit_response_doc.IsObject()) {
        APP_LOG(ERROR) << "Failed to parse unit bot result: " << unit_bot_result;
        json_response = get_error_response(-1, "Failed to parse unit bot result");
        return -1;
    }
    // Errors reported by the unit bot are passed through unchanged.
    if (!unit_response_doc.HasMember("error_code")
            || !unit_response_doc["error_code"].IsInt()
            || unit_response_doc["error_code"].GetInt() != 0) {
        json_response = unit_bot_result;
        return 0;
    }

    // The bot status is included in bot_session.
    // Make sure it is present and a string before reading it.
    if (!unit_response_doc.HasMember("result")
            || !unit_response_doc["result"].IsObject()
            || !unit_response_doc["result"].HasMember("bot_session")
            || !unit_response_doc["result"]["bot_session"].IsString()) {
        APP_LOG(ERROR) << "Missing bot session in unit bot result: " << unit_bot_result;
        json_response = get_error_response(-1, "Failed to parse bot session");
        return 0;
    }
    std::string bot_session = unit_response_doc["result"]["bot_session"].GetString();
    rapidjson::Document bot_session_doc;
    if (bot_session_doc.Parse(bot_session.c_str()).HasParseError()
            || !bot_session_doc.IsObject()) {
        APP_LOG(ERROR) << "Failed to parse bot session: " << bot_session;
        json_response = get_error_response(-1, "Failed to parse bot session");
        return 0;
    }

    if (this->handle_unsatisfied_intent(unit_response_doc,
            bot_session_doc, dm_session, json_response) == 0) {
        return 0;
    }

    // Handle satify/understood intents
    if (!bot_session_doc.HasMember("dialog_state")) {
        json_response = this->get_error_response(-1, "Failed to parse qu_result");
        return 0;
    }
    QuResult* qu_result = QuResult::parse_from_dialog_state(
        bot_id, bot_session_doc["dialog_state"]);
    if (qu_result == nullptr) {
        json_response = this->get_error_response(-1, "Failed to parse qu_result");
        return 0;
    }
    auto qu_map = new BUTIL_NAMESPACE::FlatMap<std::string, QuResult*>();
    // 2: bucket_count, initial count of buckets, big enough to avoid resize.
    // 50: load_factor, element_count * 100 / bucket_count.
    qu_map->init(2, 50);
    qu_map->insert(bot_id, qu_result);

    // Parsing request params from client_session, only string value is accepted
    std::unordered_map<std::string, std::string> request_params;
    std::string product = "default";
    if (request_doc["request"].HasMember("client_session")
            && request_doc["request"]["client_session"].IsString()) {
        std::string client_session = request_doc["request"]["client_session"].GetString();
        rapidjson::Document client_session_doc;
        if (!client_session_doc.Parse(client_session.c_str()).HasParseError()
                && client_session_doc.IsObject()) {
            for (auto& m_param: client_session_doc.GetObject()) {
                if (!m_param.value.IsString()) {
                    continue;
                }
                std::string param_name = m_param.name.GetString();
                std::string param_value = m_param.value.GetString();
                // The "product" param selects the policy product.
                if (param_name == "product") {
                    product = param_value;
                }
                request_params[param_name] = param_value;
            }
        }
    }

    PolicyOutputSession session = PolicyOutputSession::from_json_str(dm_session);
    RequestContext context(this->_remote_service_manager, log_id, request_params);
    PolicyOutput* policy_output = this->_policy_manager->resolve(
        product, qu_map, session, context);

    // The qu results stored in the map are owned here and released after resolving.
    for (auto iter = qu_map->begin(); iter != qu_map->end(); ++iter) {
        QuResult* qu_ptr = iter->second;
        if (qu_ptr != nullptr) {
            delete qu_ptr;
        }
    }
    delete qu_map;

    if (policy_output == nullptr) {
        json_response = this->get_error_response(-1, "DM policy resolve failed");
        return 0;
    }

    // Make sure the policy output meta carries the query.
    bool has_query = false;
    for (auto const& meta: policy_output->meta) {
        if (meta.key == "query") {
            has_query = true;
            break;
        }
    }
    if (!has_query) {
        KVPair meta_query;
        meta_query.key = "query";
        meta_query.value = query;
        policy_output->meta.push_back(meta_query);
    }

    this->set_dm_response(unit_response_doc, bot_session_doc, policy_output);
    delete policy_output;
    is_dmkit_response = true;
    json_response = utils::json_to_string(unit_response_doc);
    return 0;
}
int DialogManager::call_unit_bot(const std::string& access_token,
const std::string& payload,
std::string& result) {
std::string url = "/rpc/2.0/unit/bot/chat?access_token=";
url += access_token;
RemoteServiceParam rsp = {
url,
HTTP_METHOD_POST,
payload
};
RemoteServiceResult rsr;
APP_LOG(TRACE) << "Calling unit bot service, url: "<< url;
APP_LOG(TRACE) << payload;
// unit_bot is a remote service configured in conf/app/remote_services.json
if (this->_remote_service_manager->call("unit_bot", rsp, rsr) !=0) {
APP_LOG(ERROR) << "Failed to get unit bot result" ;
return -1;
}
APP_LOG(TRACE) << "Got unit bot result";
result = rsr.result;
return 0;
}
// Handles unit bot responses whose first action is not of type "understood".
// For action type "satisfy" an error response is produced. For any other type
// except "understood" the current DM session is saved into the bot session and
// the unit bot response is returned as is.
// Returns 0 when response has been filled, -1 when the action type is
// "understood" and the caller has to continue processing.
int DialogManager::handle_unsatisfied_intent(rapidjson::Document& unit_response_doc,
                                             rapidjson::Document& bot_session_doc,
                                             const std::string& dm_session,
                                             std::string& response) {
    std::string action_type;
    // Every level is type-checked before use: Size(), operator[] and GetString()
    // are only valid on arrays, objects and strings respectively.
    if (unit_response_doc.HasMember("result")
            && unit_response_doc["result"].IsObject()
            && unit_response_doc["result"].HasMember("response")
            && unit_response_doc["result"]["response"].IsObject()
            && unit_response_doc["result"]["response"].HasMember("action_list")
            && unit_response_doc["result"]["response"]["action_list"].IsArray()
            && unit_response_doc["result"]["response"]["action_list"].Size() > 0
            && unit_response_doc["result"]["response"]["action_list"][0].IsObject()
            && unit_response_doc["result"]["response"]["action_list"][0].HasMember("type")
            && unit_response_doc["result"]["response"]["action_list"][0]["type"].IsString()) {
        action_type = unit_response_doc["result"]["response"]["action_list"][0]["type"].GetString();
    } else {
        APP_LOG(WARNING) << "Failed to parse action type from unit bot response: "
            << utils::json_to_string(unit_response_doc);
    }

    if (action_type == "satisfy") {
        response = this->get_error_response(-1, "Unsupported action type satisfy");
        return 0;
    }
    if (action_type == "understood") {
        // Understood intents are resolved by the caller.
        return -1;
    }

    // DM session should be saved
    rapidjson::Value dm_session_json;
    dm_session_json.SetString(
        dm_session.c_str(), dm_session.length(), bot_session_doc.GetAllocator());
    // Create the dialog_state.contexts.dmkit path when it does not exist yet.
    if (!bot_session_doc.HasMember("dialog_state")) {
        rapidjson::Value dialog_state_json(rapidjson::kObjectType);
        bot_session_doc.AddMember("dialog_state", dialog_state_json, bot_session_doc.GetAllocator());
    }
    if (!bot_session_doc["dialog_state"].HasMember("contexts")) {
        rapidjson::Value contexts_json(rapidjson::kObjectType);
        bot_session_doc["dialog_state"].AddMember(
            "contexts", contexts_json, bot_session_doc.GetAllocator());
    }
    if (!bot_session_doc["dialog_state"]["contexts"].HasMember("dmkit")) {
        rapidjson::Value dmkit_json(rapidjson::kObjectType);
        bot_session_doc["dialog_state"]["contexts"].AddMember(
            "dmkit", dmkit_json, bot_session_doc.GetAllocator());
    }
    // Replace any session stored previously.
    if (bot_session_doc["dialog_state"]["contexts"]["dmkit"].HasMember("session")) {
        bot_session_doc["dialog_state"]["contexts"]["dmkit"].RemoveMember("session");
    }
    bot_session_doc["dialog_state"]["contexts"]["dmkit"].AddMember(
        "session", dm_session_json, bot_session_doc.GetAllocator());

    // The bot session travels as a serialized json string inside the response.
    std::string bot_session = utils::json_to_string(bot_session_doc);
    unit_response_doc["result"]["bot_session"].SetString(
        bot_session.c_str(), bot_session.length(), unit_response_doc.GetAllocator());
    response = utils::json_to_string(unit_response_doc);
    return 0;
}
// Embeds the policy output into the unit bot response:
// - the policy session is saved under dialog_state.contexts.dmkit.session of the
//   bot session;
// - the first action of the response (and of the first interaction in the bot
//   session) becomes an "event" action whose custom_reply carries the policy
//   output as a DM_RESULT event;
// - the updated bot session is serialized back into result.bot_session.
void DialogManager::set_dm_response(rapidjson::Document& unit_response_doc,
                                    rapidjson::Document& bot_session_doc,
                                    const PolicyOutput* policy_output) {
    std::string session_str = PolicyOutputSession::to_json_str(policy_output->session);
    rapidjson::Value dm_session;
    dm_session.SetString(session_str.c_str(), session_str.length(), bot_session_doc.GetAllocator());

    // Create the dialog_state.contexts.dmkit path when it does not exist yet.
    if (!bot_session_doc.HasMember("dialog_state")) {
        rapidjson::Value dialog_state_json(rapidjson::kObjectType);
        bot_session_doc.AddMember("dialog_state", dialog_state_json, bot_session_doc.GetAllocator());
    }
    if (!bot_session_doc["dialog_state"].HasMember("contexts")) {
        rapidjson::Value contexts_json(rapidjson::kObjectType);
        bot_session_doc["dialog_state"].AddMember(
            "contexts", contexts_json, bot_session_doc.GetAllocator());
    }
    if (!bot_session_doc["dialog_state"]["contexts"].HasMember("dmkit")) {
        rapidjson::Value dmkit_json(rapidjson::kObjectType);
        bot_session_doc["dialog_state"]["contexts"].AddMember(
            "dmkit", dmkit_json, bot_session_doc.GetAllocator());
    }
    // Replace any session stored previously.
    if (bot_session_doc["dialog_state"]["contexts"]["dmkit"].HasMember("session")) {
        bot_session_doc["dialog_state"]["contexts"]["dmkit"].RemoveMember("session");
    }
    bot_session_doc["dialog_state"]["contexts"]["dmkit"].AddMember(
        "session", dm_session, bot_session_doc.GetAllocator());

    // DMKit result as a custom reply
    rapidjson::StringBuffer buffer;
    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
    writer.StartObject();
    writer.Key("event_name");
    writer.String("DM_RESULT");
    writer.Key("result");
    std::string policy_output_str = PolicyOutput::to_json_str(*policy_output);
    writer.String(policy_output_str.c_str(), policy_output_str.length());
    writer.EndObject();
    std::string custom_reply = buffer.GetString();
    APP_LOG(TRACE) << "custom_reply: " << custom_reply;

    // NOTE(review): assumes interactions[0].response.action_list[0] exists in the
    // bot session with "type", "say" and "custom_reply" members -- TODO confirm.
    bot_session_doc["interactions"][0]["response"]["action_list"][0]["type"] = "event";
    bot_session_doc["interactions"][0]["response"]["action_list"][0]["say"] = "";
    // The string copy lives inside bot_session_doc, so it must come from that
    // document's allocator rather than the response document's.
    bot_session_doc["interactions"][0]["response"]["action_list"][0]["custom_reply"].SetString(
        custom_reply.c_str(), custom_reply.length(), bot_session_doc.GetAllocator());

    unit_response_doc["result"]["response"]["action_list"][0]["type"] = "event";
    unit_response_doc["result"]["response"]["action_list"][0]["say"] = "";
    unit_response_doc["result"]["response"]["action_list"][0]["custom_reply"].SetString(
        custom_reply.c_str(), custom_reply.length(), unit_response_doc.GetAllocator());

    // The bot session travels as a serialized json string inside the response.
    std::string bot_session = utils::json_to_string(bot_session_doc);
    unit_response_doc["result"]["bot_session"].SetString(
        bot_session.c_str(), bot_session.length(), unit_response_doc.GetAllocator());
}
// Builds the json error response {"error_code": <error_code>, "error_msg": <error_msg>}.
std::string DialogManager::get_error_response(int error_code, const std::string& error_msg) {
    rapidjson::StringBuffer buffer;
    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
    writer.StartObject();
    writer.Key("error_code");
    writer.Int(error_code);
    writer.Key("error_msg");
    // The explicit length keeps messages intact even if they contain '\0'.
    writer.String(error_msg.c_str(), error_msg.length());
    writer.EndObject();
    return buffer.GetString();
}
// Writes data as the json body of the http response
// and records it in the notice log under the key "ret".
void DialogManager::send_json_response(BRPC_NAMESPACE::Controller* cntl,
                                       const std::string& data) {
    auto& http_response = cntl->http_response();
    http_response.set_content_type("application/json;charset=UTF-8");

    add_notice_log("ret", data);

    auto& response_body = cntl->response_attachment();
    response_body.append(data);
}
} // namespace dmkit
================================================
FILE: src/dialog_manager.h
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "application_base.h"
#include "policy.h"
#include "policy_manager.h"
#include "qu_result.h"
#include "rapidjson.h"
#include "remote_service_manager.h"
#include "token_manager.h"
#ifndef DMKIT_DIALOG_MANAGER_H
#define DMKIT_DIALOG_MANAGER_H
namespace dmkit {
// The Dialog Manager application
class DialogManager : public ApplicationBase {
public:
DialogManager();
virtual ~DialogManager();
virtual int init();
virtual int run(BRPC_NAMESPACE::Controller* cntl);
private:
int process_request(const rapidjson::Document& request_doc,
const std::string& dm_session,
const std::string& access_token,
std::string& json_response,
bool& is_dmkit_response);
int call_unit_bot(const std::string& access_token,
const std::string& payload,
std::string& result);
int handle_unsatisfied_intent(rapidjson::Document& unit_response_doc,
rapidjson::Document& bot_session_doc,
const std::string& dm_session,
std::string& response);
std::string get_error_response(int error_code, const std::string& error_msg);
void send_json_response(BRPC_NAMESPACE::Controller* cntl, const std::string& data);
void set_dm_response(rapidjson::Document& unit_response_doc,
rapidjson::Document& bot_session_doc,
const PolicyOutput* policy_output);
RemoteServiceManager* _remote_service_manager;
PolicyManager* _policy_manager;
TokenManager* _token_manager;
};
} // namespace dmkit
#endif //DMKIT_DIALOG_MANAGER_H
================================================
FILE: src/file_watcher.cpp
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "file_watcher.h"
#include
#include
#include "app_log.h"
#include "utils.h"
namespace dmkit {
const int FileWatcher::CHECK_INTERVAL_IN_MILLS;
// Returns the single FileWatcher instance, constructed on first use.
FileWatcher& FileWatcher::get_instance() {
    static FileWatcher watcher;
    FileWatcher& shared_watcher = watcher;
    return shared_watcher;
}
// Nothing to set up here; the watcher thread is started by register_file().
FileWatcher::FileWatcher() {
}

// Asks the watcher thread to stop and waits for it when it is still joinable.
FileWatcher::~FileWatcher() {
    _is_running = false;
    if (!_watcher_thread.joinable()) {
        return;
    }
    _watcher_thread.join();
}
// Reads the last modification time of file_path and stores its textual form
// (same format as ctime()) in mtime_str.
// Returns 0 on success, -1 when the file status or the formatted time cannot be obtained.
static int get_file_last_modified_time(const std::string& file_path, std::string& mtime_str) {
    struct stat f_stat;
    if (stat(file_path.c_str(), &f_stat) != 0) {
        LOG(WARNING) << "Failed to get file modified time: " << file_path;
        return -1;
    }
    // ctime() formats into a shared static buffer, which is not safe here because
    // this function runs on the watcher thread as well as on registering threads.
    // ctime_r() writes into a caller-provided buffer of at least 26 bytes.
    char time_buffer[32];
    if (ctime_r(&f_stat.st_mtime, time_buffer) == nullptr) {
        LOG(WARNING) << "Failed to format file modified time: " << file_path;
        return -1;
    }
    mtime_str = time_buffer;
    return 0;
}
// Starts watching file_path: cb(param) is invoked by the watcher thread when the
// modification time of the file changes. Registering a path again replaces its
// previous registration. The watcher thread is started on demand.
// Returns 0 on success, -1 when the modification time of the file cannot be read.
int FileWatcher::register_file(const std::string file_path,
                               FileChangeCallback cb,
                               void* param,
                               bool level_trigger) {
    LOG(TRACE) << "FileWatcher registering file " << file_path;
    // The current modification time is the baseline later changes are compared to.
    std::string last_modified_time;
    if (get_file_last_modified_time(file_path, last_modified_time) != 0) {
        return -1;
    }
    FileStatus file_status = {file_path, last_modified_time, cb, param, level_trigger};

    // NOTE(review): assumes _mutex is a std::mutex -- confirm against file_watcher.h.
    std::lock_guard<std::mutex> lock(this->_mutex);
    this->_file_info[file_path] = file_status;
    if (!this->_is_running) {
        this->_is_running = true;
        this->_watcher_thread = std::thread(&FileWatcher::watcher_thread_func, this);
    }
    return 0;
}
// Stops watching file_path. When no file is left, the watcher thread is stopped.
// Returns 0 on success, -1 when file_path was not registered.
int FileWatcher::unregister_file(const std::string file_path) {
    LOG(TRACE) << "FileWatcher unregistering file " << file_path;
    std::thread stopped_thread;
    {
        // NOTE(review): assumes _mutex is a std::mutex -- confirm against file_watcher.h.
        std::lock_guard<std::mutex> lock(this->_mutex);
        if (this->_file_info.erase(file_path) != 1) {
            return -1;
        }
        if (this->_file_info.size() == 0
                && this->_is_running
                && this->_watcher_thread.joinable()) {
            this->_is_running = false;
            // Take the thread out while holding the lock so that a concurrent
            // register_file() can safely start a new watcher thread.
            stopped_thread = std::move(this->_watcher_thread);
        }
    }
    // Join only after the lock is released: the watcher thread acquires the same
    // mutex on every round, so joining while holding it could wait forever.
    // NOTE(review): if another thread registers a file before the old watcher
    // observes _is_running == false, the old thread keeps running and this join
    // waits until the watcher is stopped again -- a per-thread stop flag would
    // be needed to rule that out.
    if (stopped_thread.joinable()) {
        stopped_thread.join();
    }
    return 0;
}
// Body of the background polling thread.
//
// Every CHECK_INTERVAL_IN_MILLS milliseconds it re-reads the modification
// time of each registered file and invokes the registered callback when the
// value differs from the recorded one. The recorded value is advanced when the
// callback returns 0, or unconditionally for registrations that are not
// level-triggered; a level-triggered registration whose callback fails is
// therefore retried on the next poll. Callbacks run on this thread while
// _mutex is held.
void FileWatcher::watcher_thread_func() {
    LOG(TRACE) << "Watcher thread starting...";
    while (this->_is_running) {
        {
            // unregister_file() joins this thread while it still holds _mutex.
            // Blocking on the mutex here could therefore deadlock: this thread
            // would wait for the lock while its owner waits for this thread.
            // Only try to take the lock; if it is busy, skip this round and
            // re-check _is_running after the sleep.
            std::unique_lock<std::mutex> lock(this->_mutex, std::try_to_lock);
            if (lock.owns_lock()) {
                for (auto& file : this->_file_info) {
                    std::string last_modified_time;
                    // A failed stat leaves the string empty, which compares
                    // unequal to the recorded value and is treated as a change.
                    get_file_last_modified_time(file.first, last_modified_time);
                    if (last_modified_time == file.second.last_modified_time) {
                        continue;
                    }
                    LOG(TRACE) << "File Changed. " << file.first << " modified time " << last_modified_time;
                    if (file.second.callback(file.second.param) == 0
                            || !file.second.level_trigger) {
                        file.second.last_modified_time = last_modified_time;
                    }
                }
            }
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(FileWatcher::CHECK_INTERVAL_IN_MILLS));
    }
    LOG(TRACE) << "Watcher thread stopping...";
}
} // namespace dmkit
================================================
FILE: src/file_watcher.h
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DMKIT_FILE_WATCHER_H
#define DMKIT_FILE_WATCHER_H
#include
#include
#include
#include
#include
namespace dmkit {
// Callback invoked when a watched file has changed. `param` is the opaque
// pointer supplied at registration. A return value of 0 is treated by the
// watcher as "change handled".
typedef int (*FileChangeCallback)(void* param);
// Bookkeeping record for one watched file.
struct FileStatus {
// Path of the watched file.
std::string file_path;
// Textual modification time recorded at registration or after the last
// accepted change; compared against the current value on every poll.
std::string last_modified_time;
// Function to call when a change is detected.
FileChangeCallback callback;
// Opaque pointer passed to `callback`.
void* param;
// When true, last_modified_time is only advanced after `callback` returns 0.
bool level_trigger;
};
// A file watcher singleton implementation. Registered files are polled from a
// single background thread and a callback is invoked when a file's
// modification time changes.
class FileWatcher {
public:
// Returns the process-wide instance.
static FileWatcher& get_instance();
// Registers `file_path` for watching; `cb` is called with `param` on change.
// Returns 0 on success, -1 on failure.
int register_file(const std::string file_path,
FileChangeCallback cb,
void* param,
bool level_trigger=false);
// Removes a registration. Returns 0 if it existed, -1 otherwise.
int unregister_file(const std::string file_path);
// Do not need copy constructor and assignment operator for a singleton class
FileWatcher(FileWatcher const&) = delete;
void operator=(FileWatcher const&) = delete;
private:
FileWatcher();
virtual ~FileWatcher();
// Entry point of the polling thread.
void watcher_thread_func();
// Guards _file_info and the start/stop of the polling thread.
std::mutex _mutex;
// True while the polling thread is expected to keep running.
std::atomic _is_running;
std::thread _watcher_thread;
// Registrations keyed by file path.
std::unordered_map _file_info;
// Pause between two polling rounds, in milliseconds.
static const int CHECK_INTERVAL_IN_MILLS = 1000;
};
} // namespace dmkit
#endif //DMKIT_FILE_WATCHER_H
================================================
FILE: src/policy.cpp
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "policy.h"
#include "app_log.h"
#include "utils.h"
namespace dmkit {
// Builds a policy from an already parsed trigger, parameter list and output
// list; all three are copied into the object.
Policy::Policy(const PolicyTrigger& trigger,
const std::vector& params,
const std::vector& outputs)
: _trigger(trigger), _params(params), _outputs(outputs) {
}
// Returns the trigger (intent, slots, state) under which this policy applies.
const PolicyTrigger& Policy::trigger() const {
return this->_trigger;
}
// Returns the parameter definitions of this policy.
const std::vector& Policy::params() const {
return this->_params;
}
// Returns the candidate outputs of this policy.
const std::vector& Policy::outputs() const {
return this->_outputs;
}
// Parse a policy from json configuration.
// A sample policy is as following:
// {
// "trigger": {
// "intent": "INTENT_XXX",
// "slots": [
// "slot_1",
// "slot_2"
// ],
// "state": "001"
// },
// "params": [
// {
// "name": "param_name",
// "type": "slot_val",
// "value": "slot_tag"
// }
// ],
// "output": [
// {
// "assertion": [
// {
// "type": "eq",
// "value": "1,{%param_name%}"
// }
// ],
// "session": {
// "state": "002"
// "context": {
// "key1": "value1",
// "key2": "value2"
// }
// },
// "meta":{
// "key1": "value1",
// "key2": "value2"
// },
// "result": [
// {
// "type":"tts",
// "value": "Hello World"
// }
// ]
// }
// ]
// }
Policy* Policy::parse_from_json_value(const rapidjson::Value& value) {
if (!value.IsObject()) {
LOG(WARNING) << "Failed to parse policy from json, not an object: " << utils::json_to_string(value);
return nullptr;
}
if (!value.HasMember("trigger") || !value["trigger"].IsObject()
|| !value["trigger"].HasMember("intent") || !value["trigger"]["intent"].IsString()) {
LOG(WARNING) << "Failed to parse policy from json, invalid trigger";
return nullptr;
}
PolicyTrigger trigger;
trigger.intent = value["trigger"]["intent"].GetString();
if (value["trigger"].HasMember("slots") && value["trigger"]["slots"].IsArray()) {
for (auto& v : value["trigger"]["slots"].GetArray()) {
trigger.slots.push_back(v.GetString());
}
}
if (value["trigger"].HasMember("state") && value["trigger"]["state"].IsString()) {
trigger.state = value["trigger"]["state"].GetString();
}
std::vector params;
if (value.HasMember("params") && value["params"].IsArray()) {
for (auto& v : value["params"].GetArray()) {
PolicyParam param;
param.name = v["name"].GetString();
param.type = v["type"].GetString();
param.value = v["value"].GetString();
if (v.HasMember("required") && v["required"].IsBool()) {
param.required = v["required"].GetBool();
} else {
param.required = false;
}
if (v.HasMember("default") && v["default"].IsString()) {
param.default_value = v["default"].GetString();
}
params.push_back(param);
}
}
std::vector outputs;
if (value.HasMember("output") && value["output"].IsArray()) {
for (auto& v : value["output"].GetArray()) {
PolicyOutput output;
if (v.HasMember("assertion") && v["assertion"].IsArray()) {
for (auto& v_assertion: v["assertion"].GetArray()) {
std::string assertion_type = v_assertion["type"].GetString();
std::string assertion_value = v_assertion["value"].GetString();
PolicyOutputAssertion assertion = {assertion_type, assertion_value};
output.assertions.push_back(assertion);
}
}
if (v.HasMember("session") && v["session"].IsObject()) {
if (v["session"].HasMember("state") && v["session"]["state"].IsString()) {
output.session.state = v["session"]["state"].GetString();
}
if (v["session"].HasMember("context") && v["session"]["context"].IsObject()) {
for (auto& m_context: v["session"]["context"].GetObject()) {
std::string context_key = m_context.name.GetString();
std::string context_value = m_context.value.GetString();
KVPair context = {context_key, context_value};
output.session.context.push_back(context);
}
}
}
if (v.HasMember("meta") && v["meta"].IsObject()) {
for (auto& m_meta: v["meta"].GetObject()) {
std::string meta_key = m_meta.name.GetString();
std::string meta_value = m_meta.value.GetString();
KVPair meta = {meta_key, meta_value};
output.meta.push_back(meta);
}
}
for (auto& v_result : v["result"].GetArray()) {
PolicyOutputResult result;
result.type = v_result["type"].GetString();
if (v_result["value"].IsString()) {
result.values.push_back(v_result["value"].GetString());
} else if (v_result["value"].IsArray()) {
for (auto& v_result_value : v_result["value"].GetArray()) {
result.values.push_back(v_result_value.GetString());
}
}
if (v_result.HasMember("extra") && v_result["extra"].IsString()) {
result.extra = v_result["extra"].GetString();
}
output.results.push_back(result);
}
outputs.push_back(output);
}
}
return new Policy(trigger, params, outputs);;
}
// Serializes a session into a JSON object of the form
// {"domain": ..., "state": ..., "context": {key: value, ...}}.
// Context entries are written in the order they are stored.
std::string PolicyOutputSession::to_json_str(const PolicyOutputSession& session) {
rapidjson::StringBuffer buffer;
rapidjson::Writer writer(buffer);
writer.StartObject();
writer.Key("domain");
writer.String(session.domain.c_str(), session.domain.length());
writer.Key("state");
writer.String(session.state.c_str(), session.state.length());
writer.Key("context");
writer.StartObject();
for (auto const& object: session.context) {
writer.Key(object.key.c_str(), object.key.length());
writer.String(object.value.c_str(), object.value.length());
}
writer.EndObject();
writer.EndObject();
return buffer.GetString();
}
// Rebuilds a session from the JSON text produced by to_json_str().
//
// The input is not trusted: anything that fails to parse or is not a JSON
// object yields a default (empty) session, and individual fields that are
// missing or have an unexpected type are skipped instead of being read with
// the wrong accessor. Only string-valued context entries are kept.
PolicyOutputSession PolicyOutputSession::from_json_str(const std::string& json_str) {
    PolicyOutputSession session;
    rapidjson::Document session_doc;
    if (session_doc.Parse(json_str.c_str()).HasParseError() || !session_doc.IsObject()) {
        return session;
    }

    rapidjson::Value::ConstMemberIterator member = session_doc.FindMember("domain");
    if (member != session_doc.MemberEnd() && member->value.IsString()) {
        session.domain = member->value.GetString();
    }
    member = session_doc.FindMember("state");
    if (member != session_doc.MemberEnd() && member->value.IsString()) {
        session.state = member->value.GetString();
    }
    member = session_doc.FindMember("context");
    if (member == session_doc.MemberEnd() || !member->value.IsObject()) {
        return session;
    }
    for (auto& m_object : member->value.GetObject()) {
        if (!m_object.value.IsString()) {
            continue;
        }
        KVPair context = {m_object.name.GetString(), m_object.value.GetString()};
        session.context.push_back(context);
    }
    return session;
}
std::string PolicyOutput::to_json_str(const PolicyOutput& output) {
rapidjson::StringBuffer buffer;
rapidjson::Writer writer(buffer);
writer.StartObject();
writer.Key("meta");
writer.StartObject();
for (auto const& meta: output.meta) {
writer.Key(meta.key.c_str(), meta.key.length());
writer.String(meta.value.c_str(), meta.value.length());
}
writer.EndObject();
writer.Key("result");
writer.StartArray();
for (auto const& result: output.results) {
writer.StartObject();
writer.Key("type");
writer.String(result.type.c_str(), result.type.length());
writer.Key("value");
writer.String(result.values[0].c_str(), result.values[0].length());
if (!result.extra.empty()) {
rapidjson::Document extra_doc;
if (!extra_doc.Parse(result.extra.c_str()).HasParseError() && extra_doc.IsObject()) {
for (auto& v_extra: extra_doc.GetObject()) {
std::string extra_key = v_extra.name.GetString();
if (extra_key == "type" || extra_key == "value") {
LOG(WARNING) << "Unsupported extra key " << extra_key;
}
if (v_extra.value.IsString()) {
std::string extra_value = v_extra.value.GetString();
writer.Key(extra_key.c_str());
writer.String(extra_value.c_str(), extra_value.length());
} else if (v_extra.value.IsBool()) {
writer.Key(extra_key.c_str());
writer.Bool(v_extra.value.GetBool());
} else if (v_extra.value.IsInt()) {
writer.Key(extra_key.c_str());
writer.Int(v_extra.value.GetInt());
} else if (v_extra.value.IsDouble()) {
writer.Key(extra_key.c_str());
writer.Double(v_extra.value.GetDouble());
} else {
LOG(WARNING) << "Unknown extra value type " << v_extra.value.GetType();
}
}
} else {
LOG(WARNING) << "Failed to parse result extra json: " << result.extra;
}
}
writer.EndObject();
}
writer.EndArray();
writer.EndObject();
return buffer.GetString();
}
} // namespace dmkit
================================================
FILE: src/policy.h
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DMKIT_POLICY_H
#define DMKIT_POLICY_H
#include
#include
#include "rapidjson.h"
namespace dmkit {
// A plain string key/value pair, used for session context and output meta.
struct KVPair {
std::string key;
std::string value;
};
// A trigger defines the condition under which a policy is chosen,
// including intent and slots from qu result, as well as the current dm state.
struct PolicyTrigger {
// Intent name the policy responds to.
std::string intent;
// Slot names that must be present in the qu result.
std::vector slots;
// DM state the policy requires; empty means any state.
std::string state;
};
// The parameter required in a policy.
struct PolicyParam {
// Name under which the resolved value is made available.
std::string name;
// Kind of parameter (for example "slot_val" or "const"); selects how
// `value` is interpreted.
std::string type;
// Type-specific argument of the parameter.
std::string value;
// Value used when resolution fails and the parameter is not required.
std::string default_value;
// When true, a failed resolution makes the whole policy output fail.
bool required;
};
// Session for policy output, including the current domain, user defined
// contexts and a state which DM will move to.
struct PolicyOutputSession {
std::string domain;
std::string state;
// User defined key/value context entries.
std::vector context;
// Serializes `session` to a JSON string.
static std::string to_json_str(const PolicyOutputSession& session);
// Parses a session from a JSON string.
static PolicyOutputSession from_json_str(const std::string& json_str);
};
// A result item of dm output.
struct PolicyOutputResult {
// Result kind, e.g. "tts" in the sample configuration.
std::string type;
// Candidate values configured for this result.
std::vector values;
// Optional JSON object text carrying additional fields for the result.
std::string extra;
};
// A slot of the qu result as carried in the policy output.
struct PolicyOutputQuSlot {
std::string key;
std::string value;
std::string normalized_value;
};
// Qu result carried in the policy output, in case it is required as well;
// currently not used.
struct PolicyOutputQu {
std::string domain;
std::string intent;
std::vector slots;
};
// An assertion defines the condition under which a result is chosen.
struct PolicyOutputAssertion {
// Assertion kind, e.g. "eq" in the sample configuration.
std::string type;
// Type-specific argument string; may contain {%param%} placeholders.
std::string value;
};
// Schema for DMKit output
struct PolicyOutput {
// Conditions attached to this output.
std::vector assertions;
PolicyOutputQu qu;
// Key/value meta information written to the response.
std::vector meta;
// Session (state and context) associated with this output.
PolicyOutputSession session;
// Result items of this output.
std::vector results;
// Serializes the meta and result parts of `output` to a JSON string.
static std::string to_json_str(const PolicyOutput& output);
};
// A policy defines a processing(params) & response(output),
// given a trigger(intent+slots+state) condition.
class Policy {
public:
Policy(const PolicyTrigger& trigger,
const std::vector& params,
const std::vector& outputs);
const PolicyTrigger& trigger() const;
const std::vector& params() const;
const std::vector& outputs() const;
// Builds a Policy from its JSON configuration; returns nullptr when the
// configuration is rejected. The returned object is allocated with new.
static Policy* parse_from_json_value(const rapidjson::Value& value);
private:
PolicyTrigger _trigger;
std::vector _params;
std::vector _outputs;
};
} // namespace dmkit
#endif //DMKIT_POLICY_H
================================================
FILE: src/policy_manager.cpp
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "policy_manager.h"
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "app_log.h"
#include "file_watcher.h"
#include "utils.h"
namespace dmkit {
// Creates the policy set of one domain. The object takes ownership of
// `intent_policy_map` and releases it in its destructor.
DomainPolicy::DomainPolicy(const std::string& name,
                           int score,
                           IntentPolicyMap* intent_policy_map)
        : _name(name),
          _score(score),
          _intent_policy_map(intent_policy_map) {
}
// Frees every Policy, then every per-intent PolicyVector, then the map itself.
DomainPolicy::~DomainPolicy() {
    IntentPolicyMap* intents = this->_intent_policy_map;
    if (intents != nullptr) {
        for (auto entry = intents->begin(); entry != intents->end(); ++entry) {
            PolicyVector* policies = entry->second;
            if (policies == nullptr) {
                continue;
            }
            // Deleting a null pointer is a no-op, so no per-element check is needed.
            for (auto& policy : *policies) {
                delete policy;
                policy = nullptr;
            }
            delete policies;
            entry->second = nullptr;
        }
        delete intents;
        this->_intent_policy_map = nullptr;
    }
}
int DomainPolicy::score() {
return this->_score;
}
const std::string& DomainPolicy::name() {
return this->_name;
}
// Returns the intent -> policies map; ownership stays with this object.
IntentPolicyMap* DomainPolicy::intent_policy_map() {
    return _intent_policy_map;
}
// Creates an uninitialised manager; init() must be called before use.
PolicyManager::PolicyManager()
        : _user_function_manager(nullptr) {
}
// Releases the user function manager, stops watching the configuration file
// and drops this manager's reference to the policy dictionary.
PolicyManager::~PolicyManager() {
    // delete on a null pointer is a no-op, so no explicit check is required.
    delete _user_function_manager;
    _user_function_manager = nullptr;
    FileWatcher::get_instance().unregister_file(_conf_file_path);
    _p_policy_dict.reset();
}
// Deletes a whole policy dictionary: every DomainPolicy, every per-product
// DomainPolicyMap and finally the dictionary itself. Accepts nullptr.
static inline void destroy_policy_dict(ProductPolicyMap* policy_dict) {
    LOG(TRACE) << "Destroying policy dict";
    if (policy_dict == nullptr) {
        return;
    }
    for (auto product = policy_dict->begin(); product != policy_dict->end(); ++product) {
        DomainPolicyMap* domains = product->second;
        if (domains == nullptr) {
            continue;
        }
        for (auto domain = domains->begin(); domain != domains->end(); ++domain) {
            // Null entries are harmless here: deleting nullptr does nothing.
            delete domain->second;
            domain->second = nullptr;
        }
        delete domains;
        product->second = nullptr;
    }
    delete policy_dict;
}
// Loads policies from JSON configuration files.
int PolicyManager::init(const char* dir_path, const char* conf_file) {
std::string file_path;
if (dir_path != nullptr) {
file_path += dir_path;
}
if (!file_path.empty() && file_path[file_path.length() - 1] != '/') {
file_path += '/';
}
if (conf_file != nullptr) {
file_path += conf_file;
}
this->_conf_file_path = file_path;
ProductPolicyMap* policy_dict = this->load_policy_dict();
if (policy_dict == nullptr) {
APP_LOG(ERROR) << "Failed to init policy dict";
return -1;
}
this->_p_policy_dict.reset(policy_dict, [](ProductPolicyMap* p) { destroy_policy_dict(p); });
FileWatcher::get_instance().register_file(
this->_conf_file_path, PolicyManager::policy_conf_change_callback, this, true);
this->_user_function_manager = new UserFunctionManager();
if (this->_user_function_manager->init() != 0) {
APP_LOG(ERROR) << "Failed to init UserFunctionManager";
return -1;
}
return 0;
}
int PolicyManager::reload() {
LOG(TRACE) << "Reloading policy dict";
ProductPolicyMap * policy_dict = this->load_policy_dict();
if (policy_dict == nullptr) {
LOG(WARNING) << "Cannot reload policy! Policy dict load failed.";
return -1;
}
this->_p_policy_dict.reset(policy_dict, [](ProductPolicyMap* p) { destroy_policy_dict(p); });
APP_LOG(TRACE) << "Reload finished.";
return 0;
}
int PolicyManager::policy_conf_change_callback(void* param) {
PolicyManager* pm = (PolicyManager*)param;
return pm->reload();
}
// Resolves the DM output for one request.
//
// product:   product whose domain policies are consulted.
// qu_result: qu results keyed by domain name; domains without an entry are
//            evaluated against an empty qu result.
// session:   session (domain/state/context) of the previous turn.
// context:   request context; an optional "domain" param restricts the
//            search to that domain.
// Returns the first policy output that resolves successfully, or nullptr if
// the product is unknown or no policy produces an output.
PolicyOutput* PolicyManager::resolve(const std::string& product,
BUTIL_NAMESPACE::FlatMap* qu_result,
const PolicyOutputSession& session,
const RequestContext& context) {
// Take a local reference so the dictionary stays alive for this call even if
// it is replaced meanwhile.
// NOTE(review): this copy is not synchronized with the reset() performed by
// reload(); confirm both cannot run concurrently on the same manager.
std::shared_ptr p_policy_map(this->_p_policy_dict);
if (p_policy_map == nullptr) {
APP_LOG(ERROR) << "Policy resolve failed, empty policy dict";
return nullptr;
}
DomainPolicyMap** seek_result = p_policy_map->seek(product);
if (seek_result == nullptr) {
APP_LOG(WARNING) << "unkown product " << product;
return nullptr;
}
DomainPolicyMap* domain_policy_map = *seek_result;
// Candidate matching the current session domain and state, if any.
DomainPolicy* state_domain_policy = nullptr;
Policy* state_policy = nullptr;
// Remaining candidates, kept sorted by descending domain score.
std::forward_list> ranked_policies;
std::vector empty_slots;
QuResult empty_qu("", "", empty_slots);
std::string request_domain;
context.try_get_param("domain", request_domain);
for (DomainPolicyMap::iterator iter = domain_policy_map->begin();
iter != domain_policy_map->end(); ++iter) {
auto domain_policy = iter->second;
const std::string& domain_name = domain_policy->name();
if (!request_domain.empty() && domain_name != request_domain) {
continue;
}
QuResult** qu_seek_result = qu_result->seek(domain_name);
QuResult* qu = qu_seek_result != nullptr ? *qu_seek_result : &empty_qu;
// Find the best policy given the qu result in the current domain
Policy* find_result = this->find_best_policy(domain_policy, qu, session, context);
if (find_result == nullptr) {
continue;
}
// A policy whose domain and trigger state match the current DMKit domain & state is ranked top
if (!session.domain.empty() && (domain_name == session.domain)
&& !session.state.empty() && (find_result->trigger().state == session.state)) {
state_domain_policy = domain_policy;
state_policy = find_result;;
continue;
}
// In case there are multiple domain results, the policies are ranked by static domain score
auto previous_it = ranked_policies.before_begin();
for (auto it = ranked_policies.begin(); it != ranked_policies.end(); ++it) {
if (domain_policy->score() > it->first->score()) {
break;
}
previous_it = it;
}
std::pair p(domain_policy, find_result);
ranked_policies.insert_after(previous_it, p);
}
// The state-matching candidate, if present, goes to the very front.
if (state_domain_policy != nullptr && state_policy != nullptr) {
std::pair p(state_domain_policy, state_policy);
ranked_policies.insert_after(ranked_policies.before_begin(), p);
}
// Resolve policy output and return the first output resolved successfully.
for (auto& p: ranked_policies) {
const std::string& domain_name = p.first->name();
APP_LOG(TRACE) << "Resolving policy output for domain [" << domain_name << "]";
QuResult** qu_seek_result = qu_result->seek(domain_name);
QuResult* qu = qu_seek_result != nullptr ? *qu_seek_result : &empty_qu;
PolicyOutput* result = this->resolve_policy_output(domain_name,
p.second, qu, session, context);
if (result != nullptr) {
APP_LOG(TRACE) << "Final result domain [" << domain_name << "]";
return result;
}
}
return nullptr;
}
// Builds a fresh product -> domain policies dictionary from the product
// configuration file (a JSON object keyed by product name).
// Returns the new dictionary, owned by the caller, or nullptr when the file
// cannot be read or parsed or any product fails to load.
ProductPolicyMap* PolicyManager::load_policy_dict() {
    ProductPolicyMap* products = new ProductPolicyMap();
    // 10: bucket_count, initial count of buckets, big enough to avoid resize.
    // 80: load_factor, element_count * 100 / bucket_count.
    products->init(10, 80);

    // Shared failure path: release whatever has been built so far.
    auto abort_load = [products]() -> ProductPolicyMap* {
        destroy_policy_dict(products);
        return nullptr;
    };

    FILE* conf = fopen(this->_conf_file_path.c_str(), "r");
    if (conf == nullptr) {
        APP_LOG(ERROR) << "Failed to open file " << this->_conf_file_path;
        return abort_load();
    }
    char chunk[1024];
    rapidjson::FileReadStream stream(conf, chunk, sizeof(chunk));
    rapidjson::Document root;
    root.ParseStream(stream);
    fclose(conf);
    if (root.HasParseError() || !root.IsObject()) {
        APP_LOG(ERROR) << "Failed to parse products.json file";
        return abort_load();
    }

    for (auto product = root.MemberBegin(); product != root.MemberEnd(); ++product) {
        std::string product_name = product->name.GetString();
        if (!product->value.IsObject()) {
            APP_LOG(ERROR) << "Invalid product conf for " << product_name;
            return abort_load();
        }
        DomainPolicyMap* domains = this->load_domain_policy_map(product_name, product->value);
        if (domains == nullptr) {
            APP_LOG(ERROR) << "Failed to load policies for product " << product_name;
            return abort_load();
        }
        products->insert(product_name, domains);
    }
    return products;
}
// Loads the policies of every domain configured for one product.
//
// product_name: name of the product, used for logging only.
// product_json: JSON object keyed by domain name; every domain entry must be
//               an object with an integer "score" and a string "conf_path".
// A domain whose entry is malformed or whose policy file fails to load is
// skipped with a warning; the remaining domains are still loaded.
// Returns the new map, owned by the caller.
DomainPolicyMap* PolicyManager::load_domain_policy_map(const std::string& product_name,
                                                       const rapidjson::Value& product_json) {
    DomainPolicyMap* domain_policy_map = new DomainPolicyMap();
    // 10: bucket_count, initial count of buckets, big enough to avoid resize.
    // 80: load_factor, element_count * 100 / bucket_count.
    domain_policy_map->init(10, 80);
    APP_LOG(TRACE) << "Loading policies for product: " << product_name;
    for (rapidjson::Value::ConstMemberIterator domain_iter = product_json.MemberBegin();
            domain_iter != product_json.MemberEnd(); ++domain_iter) {
        std::string domain_name = domain_iter->name.GetString();
        const rapidjson::Value& domain_json = domain_iter->value;
        // FindMember() may only be called on objects; reject anything else
        // up front instead of querying a value of the wrong type.
        if (!domain_json.IsObject()) {
            APP_LOG(WARNING) << "Invalid conf for domain "
                << domain_name << " in product " << product_name << ", skipped";
            continue;
        }

        rapidjson::Value::ConstMemberIterator setting_iter;
        setting_iter = domain_json.FindMember("score");
        if (setting_iter == domain_json.MemberEnd() || !setting_iter->value.IsInt()) {
            APP_LOG(WARNING) << "Failed to parse score for domain "
                << domain_name << " in product " << product_name << ", skipped";
            continue;
        }
        int score = setting_iter->value.GetInt();

        setting_iter = domain_json.FindMember("conf_path");
        if (setting_iter == domain_json.MemberEnd() || !setting_iter->value.IsString()) {
            APP_LOG(WARNING) << "Failed to parse conf_path for domain "
                << domain_name << " in product " << product_name << ", skipped";
            continue;
        }
        std::string conf_path = setting_iter->value.GetString();

        APP_LOG(TRACE) << "Loading policies for domain " << domain_name << " from " << conf_path;
        DomainPolicy* domain_policy = this->load_domain_policy(domain_name, score, conf_path);
        if (domain_policy == nullptr) {
            APP_LOG(WARNING) << "Failed to load policy for domain "
                << domain_name << " in product " << product_name << ", skipped";
            continue;
        }
        APP_LOG(TRACE) << "Loaded policies for domain " << domain_name;
        domain_policy_map->insert(domain_name, domain_policy);
    }
    return domain_policy_map;
}
// Loads the policies of one domain from `conf_path`, a JSON array with one
// policy per element, and groups them by trigger intent. Invalid policies are
// skipped with a warning.
// Returns a new DomainPolicy owned by the caller, or nullptr if the file
// cannot be opened or is not a JSON array.
DomainPolicy* PolicyManager::load_domain_policy(const std::string& domain_name,
                                                int score,
                                                const std::string& conf_path) {
    FILE* conf = fopen(conf_path.c_str(), "r");
    if (conf == nullptr) {
        APP_LOG(ERROR) << "Failed to open file " << conf_path;
        return nullptr;
    }
    char chunk[1024];
    rapidjson::FileReadStream stream(conf, chunk, sizeof(chunk));
    rapidjson::Document root;
    root.ParseStream(stream);
    fclose(conf);
    if (root.HasParseError() || !root.IsArray()) {
        APP_LOG(ERROR) << "Failed to parse domain conf " << conf_path;
        return nullptr;
    }

    IntentPolicyMap* policies_by_intent = new IntentPolicyMap();
    // 10: bucket_count, initial count of buckets, big enough to avoid resize.
    // 80: load_factor, element_count * 100 / bucket_count.
    policies_by_intent->init(10, 80);
    for (auto item = root.Begin(); item != root.End(); ++item) {
        APP_LOG(TRACE) << "loading policy...";
        Policy* parsed = Policy::parse_from_json_value(*item);
        if (parsed == nullptr) {
            APP_LOG(WARNING) << "Found invalid policy conf in path " << conf_path << ", skipped";
            continue;
        }
        const std::string& intent = parsed->trigger().intent;
        // Create the per-intent bucket the first time the intent is seen.
        if (policies_by_intent->seek(intent) == nullptr) {
            policies_by_intent->insert(intent, new PolicyVector);
        }
        (*policies_by_intent)[intent]->push_back(parsed);
    }

    APP_LOG(TRACE) << "initializing domain policy...";
    DomainPolicy* loaded = new DomainPolicy(domain_name, score, policies_by_intent);
    APP_LOG(TRACE) << "finish initializing domain policy...";
    return loaded;
}
// When multiple policies satisfy the current intent,
// the following strategy is applied to find the best one:
// 1. A policy with a non-empty trigger state must match the current DM state.
// 2. Every trigger slot must be present in the qu result at least as many
//    times as it is listed in the trigger.
// 3. Among the remaining policies, one with a trigger state beats one
//    without, and otherwise the one with more trigger slots wins; on a tie the
//    earlier policy is kept.
// Returns the selected policy, or nullptr when no candidate qualifies.
static Policy* find_best_policy_from_candidates(PolicyVector& policy_vector,
std::string& state,
std::unordered_multiset& qu_slot_set) {
Policy* policy_result = nullptr;
for (auto const& policy: policy_vector) {
if (!policy->trigger().state.empty() && policy->trigger().state != state) {
continue;
}
bool missing_slot = false;
std::unordered_multiset trigger_slot_set;
for (auto const& slot: policy->trigger().slots) {
trigger_slot_set.insert(slot);
}
// Equal elements are adjacent when iterating a multiset, so each distinct
// slot name is visited once by advancing over all of its occurrences.
for (auto iter = trigger_slot_set.begin(); iter != trigger_slot_set.end();) {
int trigger_slot_cnt = trigger_slot_set.count(*iter);
int qu_slot_cnt = qu_slot_set.count(*iter);
if (qu_slot_cnt < trigger_slot_cnt) {
missing_slot = true;
break;
}
std::advance(iter, trigger_slot_cnt);
}
if (missing_slot) {
continue;
}
// First qualifying candidate.
if (policy_result == nullptr) {
policy_result = policy;
continue;
}
// A state-specific policy is never replaced by a state-agnostic one...
if (!policy_result->trigger().state.empty() && policy->trigger().state.empty()) {
continue;
}
// ...and always replaces a state-agnostic one.
if (policy_result->trigger().state.empty() && !policy->trigger().state.empty()) {
policy_result = policy;
continue;
}
// Same state specificity: prefer the policy with more trigger slots.
if (policy_result->trigger().slots.size() < policy->trigger().slots.size()) {
policy_result = policy;
}
}
return policy_result;
}
// Selects the best policy of one domain for the given qu result.
//
// The session state is only taken into account when the session belongs to
// this domain. Policies registered for the qu intent are tried first; if none
// qualifies, the policies registered for "dmkit_intent_fallback" are tried.
// `context` is currently unused.
// Returns the selected policy or nullptr.
Policy* PolicyManager::find_best_policy(DomainPolicy* domain_policy,
QuResult* qu_result,
const PolicyOutputSession& session,
const RequestContext& context) {
(void) context;
std::string state;
if (domain_policy->name() == session.domain) {
state = session.state;
}
std::unordered_multiset qu_slot_set;
IntentPolicyMap* intent_policy_map = domain_policy->intent_policy_map();
PolicyVector** policy_vector_seek = nullptr;
Policy* policy_result = nullptr;
// Collect the slot keys of the qu result, keeping duplicates.
for (auto const& slot: qu_result->slots()) {
qu_slot_set.insert(slot.key());
}
// Policy with matching intent.
const std::string& intent = qu_result->intent();
policy_vector_seek = intent_policy_map->seek(intent);
if (policy_vector_seek != nullptr) {
APP_LOG(TRACE) << "intent [" << intent << "] candidate count [" << (*policy_vector_seek)->size() << "]";
policy_result = find_best_policy_from_candidates(**policy_vector_seek, state, qu_slot_set);
}
// Fallback policy when none of the policies match intent.
if (policy_result == nullptr) {
const std::string fallback_intent = "dmkit_intent_fallback";
policy_vector_seek = intent_policy_map->seek(fallback_intent);
if (policy_vector_seek != nullptr) {
APP_LOG(TRACE) << "dmkit_intent_fallback candidate count [" << (*policy_vector_seek)->size() << "]";
policy_result = find_best_policy_from_candidates(**policy_vector_seek, state, qu_slot_set);
}
}
return policy_result;
}
// Resolve a string with params in it.
//
// Every placeholder of the form {%name%} in `unresolved` is replaced by the
// value stored under `name` in `param_map`; text outside placeholders is
// copied unchanged.
// Returns true on success, in which case `unresolved` holds the resolved text.
// Returns false, leaving `unresolved` untouched, when a "%}" appears without
// a preceding "{%", when a "{%" is never closed, or when a name is not found
// in `param_map`. An empty input resolves successfully.
static bool try_resolve_params(std::string& unresolved,
const std::unordered_map& param_map) {
if (unresolved.empty()) {
return true;
}
std::string resolved;
// True while scanning between "{%" and "%}".
bool is_param = false;
// Start of the text that has not been copied or consumed yet.
unsigned int last_index = 0;
for (unsigned int i = 0; i < unresolved.length(); ++i) {
if (unresolved[i] == '{' && i + 1 < unresolved.length() && unresolved[i + 1] == '%' ) {
// Flush the literal text before the opening delimiter.
resolved += unresolved.substr(last_index, i - last_index);
last_index = i + 2;
is_param = true;
++i;
}
if (unresolved[i] == '%' && i + 1 < unresolved.length() && unresolved[i + 1] == '}') {
if (!is_param) {
APP_LOG(WARNING) << "Cannot resolve params in string, invalid format. " << unresolved;
return false;
}
std::string param_name = unresolved.substr(last_index, i - last_index);
std::unordered_map::const_iterator find_res = param_map.find(param_name);
if (find_res == param_map.end()) {
APP_LOG(WARNING) << "Cannot resolve params in string, unknow param. "
<< unresolved << " " << param_name;
return false;
}
resolved += find_res->second;
last_index = i + 2;
is_param = false;
++i;
}
}
// An opening delimiter without a closing one.
if (is_param) {
APP_LOG(WARNING) << "Cannot resolve params in string, invalid format. " << unresolved;
return false;
}
// Copy the literal text after the last placeholder.
if (last_index < unresolved.length()) {
resolved += unresolved.substr(last_index);
}
unresolved = resolved;
return true;
}
// Resolve a string with delimiter and params in it.
//
// `unresolved` is split at every `delimiter`; each part is passed through
// utils::trim and then has its {%name%} placeholders resolved against
// `param_map`. The resolved parts are stored in `result` in order. Empty
// parts between delimiters are kept, but nothing is emitted for an empty
// remainder after the last delimiter.
// Returns true on success; returns false with `result` cleared as soon as a
// part cannot be resolved.
static bool try_resolve_param_list(const std::string& unresolved,
const char delimiter,
const std::unordered_map& param_map,
std::vector &result) {
std::size_t pos = 0;
std::size_t last_pos = 0;
result.clear();
while (last_pos < unresolved.length() && (pos = unresolved.find(delimiter, last_pos)) != std::string::npos) {
std::string part = unresolved.substr(last_pos, pos - last_pos);
utils::trim(part);
if (!try_resolve_params(part, param_map)) {
result.clear();
return false;
}
result.push_back(part);
last_pos = pos + 1;
}
// Remainder after the last delimiter (or the whole string if there is none).
if (last_pos < unresolved.length()) {
std::string part = unresolved.substr(last_pos);
utils::trim(part);
if (!try_resolve_params(part, param_map)) {
result.clear();
return false;
}
result.push_back(part);
}
return true;
}
// Builds the concrete output of `policy` for the current request.
//
// Steps:
//   1. Build the parameter map: session context entries, QU slots, then the
//      parameters declared by the policy. Policy parameters are evaluated in
//      declaration order and stored under their own name, so a later
//      parameter may reference an earlier one.
//   2. Select the first candidate output whose assertions all hold.
//   3. Substitute parameters into a copy of that output (meta, session, qu
//      and results). When a result has several candidate values, one of them
//      is picked using the current time.
//   4. Append the saved "dmkit_param_context_*" parameters and the first tts
//      result (as "dmkit_param_last_tts") to the output session context.
//
// Parameters:
//   domain     - written to the session domain of the returned output.
//   policy     - supplies the parameter definitions and candidate outputs.
//   qu_result  - supplies the intent and slots.
//   session    - current session state and context.
//   context    - request context, used for request parameters and passed to
//                user functions.
//
// Returns a newly allocated PolicyOutput (this function keeps no reference to
// it), or nullptr when a required parameter cannot be resolved, no candidate
// output passes its assertions, or substitution fails.
PolicyOutput* PolicyManager::resolve_policy_output(const std::string& domain,
        Policy* policy,
        QuResult* qu_result,
        const PolicyOutputSession& session,
        const RequestContext& context) {
    // Process parameters
    std::unordered_map<std::string, std::string> param_map;

    // Default parameters from the session context. The last tts entry keeps
    // its own key; every other entry is exposed as dmkit_param_context_<key>.
    for (auto const& session_item: session.context) {
        if (session_item.key == "dmkit_param_last_tts") {
            param_map[session_item.key] = session_item.value;
            continue;
        }
        std::string param_key = "dmkit_param_context_";
        param_key += session_item.key;
        param_map[param_key] = session_item.value;
    }

    // Default parameters from QU slots, exposed as dmkit_param_slot_<key>.
    // The first slot seen for a key wins; the normalized value is preferred
    // over the raw value when it is not empty.
    for (auto const& slot: qu_result->slots()) {
        std::string param_key = "dmkit_param_slot_";
        param_key += slot.key();
        if (param_map.find(param_key) != param_map.end()) {
            continue;
        }
        std::string param_value = slot.normalized_value();
        if (param_value.empty()) {
            param_value = slot.value();
        }
        param_map[param_key] = param_value;
    }

    // Parameters declared by the policy.
    for (auto const& param: policy->params()) {
        APP_LOG(TRACE) << "resolving parameter [" << param.name << "]";
        std::string value;
        if (param.type == "slot_val" || param.type == "slot_val_ori") {
            // param.value is "<slot key>[,<index>]": take the index-th slot
            // (0-based) carrying that key.
            bool success = false;
            std::vector<std::string> args;
            utils::split(param.value, ',', args);
            int index = 0;
            if (args.empty()) {
                // Without a slot key there is nothing to look up; avoid
                // reading args[0] from an empty vector.
                APP_LOG(WARNING) << "Missing slot name for slot_val parameter: " << param.value;
                index = -1;
            } else if (args.size() >= 2 && !utils::try_atoi(args[1], index)) {
                APP_LOG(WARNING) << "Invalid index for slot_val parameter: " << param.value;
                index = -1;
            }
            for (auto const& slot: qu_result->slots()) {
                if (index < 0) {
                    break;
                }
                if (slot.key() == args[0]) {
                    if (index > 0) {
                        index--;
                        continue;
                    }
                    // slot_val prefers the normalized value; slot_val_ori
                    // always takes the raw value.
                    if (param.type == "slot_val" && !slot.normalized_value().empty()) {
                        value = slot.normalized_value();
                        success = true;
                        break;
                    }
                    value = slot.value();
                    success = true;
                    break;
                }
            }
            if (!success) {
                if (param.required) {
                    return nullptr;
                }
                value = param.default_value;
            }
        } else if (param.type == "qu_intent") {
            value = qu_result->intent();
        } else if (param.type == "session_state") {
            value = session.state;
        } else if (param.type == "session_context") {
            bool success = false;
            for (auto const& obj: session.context) {
                if (obj.key == param.value) {
                    value = obj.value;
                    success = true;
                    break;
                }
            }
            if (!success) {
                if (param.required) {
                    return nullptr;
                }
                value = param.default_value;
            }
        } else if (param.type == "const") {
            value = param.value;
        } else if (param.type == "string") {
            value = param.value;
            if (!try_resolve_params(value, param_map)) {
                if (param.required) {
                    return nullptr;
                }
                value = param.default_value;
            }
        } else if (param.type == "request_param") {
            // Bind by reference so the request parameter map is not copied
            // for every parameter lookup.
            const auto& request_params = context.params();
            auto find_res = request_params.find(param.value);
            bool success = false;
            if (find_res != request_params.end()) {
                value = find_res->second;
                success = true;
            }
            if (!success) {
                if (param.required) {
                    return nullptr;
                }
                value = param.default_value;
            }
        } else if (param.type == "func_val") {
            // param.value is "<function name>[:<arg>,<arg>,...]".
            std::string func_val = param.value;
            std::string func_name;
            std::vector<std::string> args;
            std::size_t pos = func_val.find(':');
            bool has_error = false;
            if (pos == std::string::npos || pos == func_val.length() - 1) {
                func_name = func_val;
                if (!try_resolve_params(func_name, param_map)) {
                    has_error = true;
                }
            } else {
                func_name = func_val.substr(0, pos);
                if (!try_resolve_params(func_name, param_map)) {
                    has_error = true;
                }
                std::string arg_list = func_val.substr(pos + 1);
                if (!try_resolve_param_list(arg_list, ',', param_map, args)) {
                    has_error = true;
                }
            }
            utils::trim(func_name);
            // The user function is only invoked when resolution succeeded.
            if (has_error
                    || this->_user_function_manager->call_user_function(
                            func_name, args, context, value) != 0) {
                has_error = true;
            }
            if (has_error) {
                if (param.required) {
                    return nullptr;
                }
                value = param.default_value;
            }
        } else {
            APP_LOG(WARNING) << "Unknown param type " << param.type;
            if (param.required) {
                return nullptr;
            }
            value = param.default_value;
        }
        param_map[param.name] = value;
        APP_LOG(TRACE) << "Parameter value [" << value << "]";
    }

    // Select the first candidate output whose assertions all pass.
    int selected_output_index = -1;
    for (unsigned int i = 0; i < policy->outputs().size(); ++i) {
        bool failed = false;
        // Process assertions
        APP_LOG(TRACE) << "Candidate output size [" << policy->outputs().size() << "]";
        for (unsigned int j = 0; j < policy->outputs()[i].assertions.size(); ++j) {
            std::string assertion_type = policy->outputs()[i].assertions[j].type;
            std::string assertion_value = policy->outputs()[i].assertions[j].value;
            if (!try_resolve_params(assertion_value, param_map)) {
                failed = true;
                break;
            }
            APP_LOG(TRACE) << "evaluating assertion, type[" << assertion_type
                    << "] value[" << assertion_value << "]";
            if (assertion_type == "not_empty") {
                if (assertion_value.empty()) {
                    failed = true;
                    break;
                }
            } else if (assertion_type == "empty") {
                if (!assertion_value.empty()) {
                    failed = true;
                    break;
                }
            } else if (assertion_type == "in") {
                // "<needle>,<candidate>,...": passes when the first item
                // equals any of the following items.
                std::vector<std::string> value_list;
                bool has_match = false;
                if (try_resolve_param_list(assertion_value, ',', param_map, value_list)) {
                    for (unsigned int k = 1; k < value_list.size(); k++) {
                        if (value_list[0] == value_list[k]) {
                            has_match = true;
                            break;
                        }
                    }
                }
                if (!has_match) {
                    failed = true;
                }
            } else if (assertion_type == "not_in") {
                // Fails only when the first item equals a following item.
                std::vector<std::string> value_list;
                if (try_resolve_param_list(assertion_value, ',', param_map, value_list)) {
                    for (unsigned int k = 1; k < value_list.size(); k++) {
                        if (value_list[0] == value_list[k]) {
                            failed = true;
                            break;
                        }
                    }
                }
            } else if (assertion_type == "eq") {
                std::vector<std::string> value_list;
                if (!try_resolve_param_list(assertion_value, ',', param_map, value_list)
                        || value_list.size() < 2
                        || value_list[0] != value_list[1]) {
                    failed = true;
                }
            } else if (assertion_type == "gt") {
                // Numeric comparison: first item strictly greater than second.
                std::vector<std::string> value_list;
                double left_val = 0;
                double right_val = 0;
                if (!try_resolve_param_list(assertion_value, ',', param_map, value_list)
                        || value_list.size() < 2
                        || !utils::try_atof(value_list[0], left_val)
                        || !utils::try_atof(value_list[1], right_val)
                        || left_val <= right_val) {
                    failed = true;
                }
            } else if (assertion_type == "ge") {
                // Numeric comparison: first item greater than or equal to second.
                std::vector<std::string> value_list;
                double left_val = 0;
                double right_val = 0;
                if (!try_resolve_param_list(assertion_value, ',', param_map, value_list)
                        || value_list.size() < 2
                        || !utils::try_atof(value_list[0], left_val)
                        || !utils::try_atof(value_list[1], right_val)
                        || left_val < right_val) {
                    failed = true;
                }
            } else {
                APP_LOG(WARNING) << "Unknown assertion type " << assertion_type;
                failed = true;
                break;
            }
        }
        if (!failed) {
            APP_LOG(TRACE) << "selected output at index [" << i << "]";
            selected_output_index = i;
            break;
        }
    }
    if (selected_output_index == -1) {
        return nullptr;
    }

    // Work on a copy of the selected output so the policy itself is untouched.
    PolicyOutput output;
    output.meta = policy->outputs()[selected_output_index].meta;
    output.session = policy->outputs()[selected_output_index].session;
    output.qu = policy->outputs()[selected_output_index].qu;
    output.results = policy->outputs()[selected_output_index].results;

    for (unsigned int i = 0; i < output.meta.size(); ++i) {
        if (!try_resolve_params(output.meta[i].key, param_map)) {
            return nullptr;
        }
        if (!try_resolve_params(output.meta[i].value, param_map)) {
            return nullptr;
        }
    }
    if (!try_resolve_params(output.session.domain, param_map)) {
        return nullptr;
    }
    if (!try_resolve_params(output.session.state, param_map)) {
        return nullptr;
    }
    for (unsigned int i = 0; i < output.session.context.size(); ++i) {
        if (!try_resolve_params(output.session.context[i].key, param_map)) {
            return nullptr;
        }
        if (!try_resolve_params(output.session.context[i].value, param_map)) {
            return nullptr;
        }
    }
    if (!try_resolve_params(output.qu.domain, param_map)) {
        return nullptr;
    }
    if (!try_resolve_params(output.qu.intent, param_map)) {
        return nullptr;
    }
    for (unsigned int i = 0; i < output.qu.slots.size(); ++i) {
        if (!try_resolve_params(output.qu.slots[i].key, param_map)) {
            return nullptr;
        }
        if (!try_resolve_params(output.qu.slots[i].value, param_map)) {
            return nullptr;
        }
        if (!try_resolve_params(output.qu.slots[i].normalized_value, param_map)) {
            return nullptr;
        }
    }
    for (unsigned int i = 0; i < output.results.size(); ++i) {
        if (output.results[i].values.empty()) {
            APP_LOG(WARNING) << "empty result value!";
            return nullptr;
        }
        if (output.results[i].values.size() > 1) {
            // Several candidate values: keep exactly one, chosen from the
            // current time.
            int size = output.results[i].values.size();
            int index = std::time(nullptr) % size;
            if (index < 0 || index >= size) {
                index = 0;
            }
            std::string value = output.results[i].values[index];
            output.results[i].values.clear();
            output.results[i].values.push_back(value);
        }
        if (!try_resolve_params(output.results[i].values[0], param_map)) {
            return nullptr;
        }
    }

    PolicyOutput* output_ptr = new PolicyOutput(output);
    output_ptr->session.domain = domain;

    // Saved parameters: every dmkit_param_context_<key> entry is written back
    // to the session context under <key>.
    for (auto const& param: param_map) {
        if (param.first.find("dmkit_param_context_") != 0) {
            continue;
        }
        // 20 == strlen("dmkit_param_context_")
        std::string context_key = param.first.substr(20);
        if (context_key.empty()) {
            continue;
        }
        KVPair saved_context = {context_key, param.second};
        output_ptr->session.context.push_back(saved_context);
    }

    // Remember the first tts result (empty when there is none).
    std::string first_tts;
    for (unsigned int i = 0; i < output_ptr->results.size(); ++i) {
        if (output_ptr->results[i].type == "tts") {
            first_tts = output_ptr->results[i].values[0];
            break;
        }
    }
    KVPair last_tts_context = {"dmkit_param_last_tts", first_tts};
    output_ptr->session.context.push_back(last_tts_context);
    return output_ptr;
}
} // namespace dmkit
================================================
FILE: src/policy_manager.h
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DMKIT_POLICY_MANAGER_H
#define DMKIT_POLICY_MANAGER_H
#include
#include
#include
#include
#include
#include "butil.h"
#include "policy.h"
#include "qu_result.h"
#include "request_context.h"
#include "user_function_manager.h"
namespace dmkit {
// Container type holding the policies registered for one intent.
typedef std::vector PolicyVector;
// Lookup table from an intent to its policies (see DomainPolicy::intent_policy_map).
typedef BUTIL_NAMESPACE::FlatMap IntentPolicyMap;
// Holds all policies given a domain.
class DomainPolicy {
public:
// Takes the domain name, the domain score and the intent-to-policies map.
// NOTE(review): whether the destructor releases intent_policy_map is decided
// in the .cpp file - confirm ownership there before sharing the pointer.
DomainPolicy(const std::string& name, int score, IntentPolicyMap* intent_policy_map);
~DomainPolicy();
// Name of the domain.
const std::string& name();
// Score of the domain.
int score();
// Maps an intent to a vector of policies
IntentPolicyMap* intent_policy_map();
private:
std::string _name;
int _score;
IntentPolicyMap* _intent_policy_map;
};
// Lookup table of domain policies; presumably keyed by domain name - confirm.
typedef BUTIL_NAMESPACE::FlatMap DomainPolicyMap;
// Lookup table of per-product domain policy maps; presumably keyed by product name - confirm.
typedef BUTIL_NAMESPACE::FlatMap ProductPolicyMap;
// Loads policy configuration and resolves policy outputs for requests.
class PolicyManager {
public:
PolicyManager();
~PolicyManager();
// Initializes the manager from a configuration directory and file name.
// NOTE(review): return convention is presumably 0 on success - confirm in the .cpp.
int init(const char* dir_path, const char* conf_file);
// Reloads the policy configuration.
int reload();
// Resolve a policy output given a qu result and current dm session
PolicyOutput* resolve(const std::string& product,
BUTIL_NAMESPACE::FlatMap* qu_result,
const PolicyOutputSession& session,
const RequestContext& context);
// Static callback taking an opaque pointer.
// NOTE(review): presumably invoked when the policy configuration file
// changes, with param pointing at the PolicyManager - confirm.
static int policy_conf_change_callback(void* param);
private:
// Loads the domain policies of one product from its json description.
DomainPolicyMap* load_domain_policy_map(const std::string& product_name,
const rapidjson::Value& product_json);
// Loads the policies of one domain from the configuration at conf_path.
DomainPolicy* load_domain_policy(const std::string& domain_name,
int score,
const std::string& conf_path);
// Picks the policy to apply for the given qu result and session.
Policy* find_best_policy(DomainPolicy* domain_policy,
QuResult* qu_result,
const PolicyOutputSession& session,
const RequestContext& context);
// Resolves parameters, selects the first output of the policy whose
// assertions pass and returns a newly allocated, parameter-substituted copy
// of it, or nullptr on failure.
PolicyOutput* resolve_policy_output(const std::string& domain,
Policy* policy,
QuResult* qu_result,
const PolicyOutputSession& session,
const RequestContext& context);
// Loads the complete product-to-policies dictionary.
ProductPolicyMap* load_policy_dict();
// Path of the configuration file.
std::string _conf_file_path;
// Currently loaded policy dictionary.
std::shared_ptr _p_policy_dict;
// Used by resolve_policy_output to evaluate "func_val" parameters.
UserFunctionManager* _user_function_manager;
};
} // namespace dmkit
#endif //DMKIT_POLICY_MANAGER_H
================================================
FILE: src/qu_result.cpp
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "qu_result.h"
#include "app_log.h"
#include "utils.h"
namespace dmkit {
Slot::Slot(const std::string& key, const std::string& value, const std::string& normalized_value)
: _key(key), _value(value), _normalized_value(normalized_value) {
}
const std::string& Slot::key() const {
return this->_key;
}
const std::string& Slot::value() const {
return this->_value;
}
const std::string& Slot::normalized_value() const {
return this->_normalized_value;
}
// Stores copies of the domain, the intent and the slot list.
QuResult::QuResult(const std::string& domain,
        const std::string& intent,
        const std::vector& slots)
        : _domain(domain),
          _intent(intent),
          _slots(slots) {}
// Parse QU result from dialog state in unit response.
//
// Fields read from `value`:
//   "intents"    - non-empty array; the "name" of its LAST element is the intent.
//   "user_slots" - object mapping a slot tag to an object whose "values"
//                  member maps a normalized value to
//                  { "state": <int>, "original_name": <string> }.
//
// Slots whose state is 0 or 4 are skipped (see the state table below).
// Returns a newly allocated QuResult, or nullptr (after logging a warning)
// when the json does not have the shape described above.
QuResult* QuResult::parse_from_dialog_state(const std::string& domain,
        const rapidjson::Value& value) {
    if (!value.IsObject()) {
        APP_LOG(WARNING) << "Failed to parse qu result from json, not an object";
        return nullptr;
    }

    auto intents_iter = value.FindMember("intents");
    if (intents_iter == value.MemberEnd()
            || !intents_iter->value.IsArray()
            || intents_iter->value.Size() < 1) {
        APP_LOG(WARNING) << "Failed to parse qu result from json";
        // Bail out here: indexing a missing or empty "intents" below is invalid.
        return nullptr;
    }
    const rapidjson::Value& intents = intents_iter->value;
    const rapidjson::Value& last_intent = intents[intents.Size() - 1];
    if (!last_intent.IsObject()) {
        APP_LOG(WARNING) << "Failed to parse qu result from json, invalid intent";
        return nullptr;
    }
    auto name_iter = last_intent.FindMember("name");
    if (name_iter == last_intent.MemberEnd() || !name_iter->value.IsString()) {
        APP_LOG(WARNING) << "Failed to parse qu result from json, invalid intent name";
        return nullptr;
    }
    std::string intent = name_iter->value.GetString();

    auto user_slots_iter = value.FindMember("user_slots");
    if (user_slots_iter == value.MemberEnd() || !user_slots_iter->value.IsObject()) {
        APP_LOG(WARNING) << "Failed to parse qu result from json, invalid user_slots";
        return nullptr;
    }

    std::vector<Slot> slots;
    for (auto& m_slot: user_slots_iter->value.GetObject()) {
        std::string tag = m_slot.name.GetString();
        if (!m_slot.value.IsObject()) {
            APP_LOG(WARNING) << "Failed to parse qu result from json, invalid slot " << tag;
            return nullptr;
        }
        auto values_iter = m_slot.value.FindMember("values");
        if (values_iter == m_slot.value.MemberEnd() || !values_iter->value.IsObject()) {
            APP_LOG(WARNING) << "Failed to parse qu result from json, invalid values for slot "
                    << tag;
            return nullptr;
        }
        for (auto& m_value: values_iter->value.GetObject()) {
            const rapidjson::Value& detail = m_value.value;
            if (!detail.IsObject()) {
                APP_LOG(WARNING) << "Failed to parse qu result from json, invalid value for slot "
                        << tag;
                return nullptr;
            }
            auto state_iter = detail.FindMember("state");
            if (state_iter == detail.MemberEnd() || !state_iter->value.IsInt()) {
                APP_LOG(WARNING) << "Failed to parse qu result from json, invalid state for slot "
                        << tag;
                return nullptr;
            }
            int slot_state = state_iter->value.GetInt();
            // Slot state possible values are
            // 0: slot not filled
            // 1: slot is filled by default value
            // 2: slot is filled by SLU
            // 4: slot was filled but has been replaced by other value
            if (slot_state == 0 || slot_state == 4) {
                continue;
            }
            auto original_iter = detail.FindMember("original_name");
            if (original_iter == detail.MemberEnd() || !original_iter->value.IsString()) {
                APP_LOG(WARNING) << "Failed to parse qu result from json, "
                        << "invalid original_name for slot " << tag;
                return nullptr;
            }
            // The member name is the normalized value; "original_name"
            // carries the raw text.
            std::string normalized_value = m_value.name.GetString();
            std::string original_value = original_iter->value.GetString();
            slots.push_back(Slot(tag, original_value, normalized_value));
        }
    }

    QuResult* result = new QuResult(domain, intent, slots);
    APP_LOG(TRACE) << result->to_string();
    return result;
}
const std::string& QuResult::domain() const {
return this->_domain;
}
const std::string& QuResult::intent() const {
return this->_intent;
}
const std::vector& QuResult::slots() const {
return this->_slots;
}
// Renders the result on one line for logging, in the form
//   domain:<domain> intent:<intent> slots: {<key>:<value>(<normalized>) ...}
// The "(<normalized>)" part is omitted when the normalized value is empty,
// and every slot is followed by a single space.
std::string QuResult::to_string() const {
    std::string text = "domain:";
    text.append(this->domain());
    text.append(" intent:").append(this->intent());
    text.append(" slots: {");
    for (const auto& slot : this->slots()) {
        text.append(slot.key()).append(":").append(slot.value());
        const std::string& normalized = slot.normalized_value();
        if (!normalized.empty()) {
            text.append("(").append(normalized).append(")");
        }
        text.append(" ");
    }
    text.append("}");
    return text;
}
} // namespace dmkit
================================================
FILE: src/qu_result.h
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DMKIT_QU_RESULT_H
#define DMKIT_QU_RESULT_H
#include
#include
#include "rapidjson.h"
namespace dmkit {
// A single slot of a QU result: a key with its raw and normalized values.
class Slot {
public:
// Copies the three strings into the slot.
Slot(const std::string& key, const std::string& value, const std::string& normalized_value);
// Key of the slot.
const std::string& key() const;
// Raw value of the slot.
const std::string& value() const;
// Normalized value of the slot; may be empty.
const std::string& normalized_value() const;
private:
std::string _key;
std::string _value;
std::string _normalized_value;
};
// Query Understanding, or natural language understanding (NLU) results for user input,
// generally includes domain, intent and slots.
class QuResult {
public:
// Copies the domain, intent and slots into the result.
QuResult(const std::string& domain, const std::string& intent, const std::vector& slots);
// Parses a QU result from the dialog state json of a unit response.
// Returns a newly allocated QuResult; returns nullptr when value is not a
// json object.
static QuResult* parse_from_dialog_state(const std::string& domain,
const rapidjson::Value& value);
// Domain of the result.
const std::string& domain() const;
// Intent of the result.
const std::string& intent() const;
// Slots of the result.
const std::vector& slots() const;
// One-line textual form of the result, used for logging.
std::string to_string() const;
private:
std::string _domain;
std::string _intent;
std::vector _slots;
};
} // namespace dmkit
#endif //DMKIT_QU_RESULT_H
================================================
FILE: src/rapidjson.h
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DMKIT_RAPIDJSON_H
#define DMKIT_RAPIDJSON_H
#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
#endif
#include "thirdparty/rapidjson/allocators.h"
#include "thirdparty/rapidjson/document.h"
#include "thirdparty/rapidjson/encodedstream.h"
#include "thirdparty/rapidjson/encodings.h"
#include "thirdparty/rapidjson/filereadstream.h"
#include "thirdparty/rapidjson/filewritestream.h"
#include "thirdparty/rapidjson/fwd.h"
#include "thirdparty/rapidjson/istreamwrapper.h"
#include "thirdparty/rapidjson/memorybuffer.h"
#include "thirdparty/rapidjson/memorystream.h"
#include "thirdparty/rapidjson/ostreamwrapper.h"
#include "thirdparty/rapidjson/pointer.h"
#include "thirdparty/rapidjson/prettywriter.h"
#include "thirdparty/rapidjson/rapidjson.h"
#include "thirdparty/rapidjson/reader.h"
#include "thirdparty/rapidjson/schema.h"
#include "thirdparty/rapidjson/stream.h"
#include "thirdparty/rapidjson/stringbuffer.h"
#include "thirdparty/rapidjson/writer.h"
#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8)
#pragma GCC diagnostic pop
#endif
#endif //DMKIT_RAPIDJSON_H
================================================
FILE: src/remote_service_manager.cpp
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "remote_service_manager.h"
#include
#include
#include
#include "app_log.h"
#include "file_watcher.h"
#include "rapidjson.h"
#include "thread_data_base.h"
namespace dmkit {
// Nothing is set up at construction time; init() loads the configuration.
RemoteServiceManager::RemoteServiceManager() {}

// Unregisters the configuration file from the file watcher, then drops this
// manager's reference to the channel map.
RemoteServiceManager::~RemoteServiceManager() {
    FileWatcher::get_instance().unregister_file(_conf_file_path);
    _p_channel_map.reset();
}
// Deletes every rpc channel stored in the map and then the map itself.
// A null map is accepted (only the trace line is emitted); entries without a
// channel are left alone.
static inline void destroy_channel_map(ChannelMap* p) {
    APP_LOG(TRACE) << "Destroying service map...";
    if (p != nullptr) {
        for (auto iter = p->begin(); iter != p->end(); ++iter) {
            auto& rpc_channel = iter->second.channel;
            if (rpc_channel == nullptr) {
                continue;
            }
            delete rpc_channel;
            rpc_channel = nullptr;
            APP_LOG(TRACE) << "Destroyed service" << iter->first;
        }
        delete p;
    }
}
int RemoteServiceManager::init(const char* path, const char* conf) {
std::string file_path;
if (path != nullptr) {
file_path += path;
}
if (!file_path.empty() && file_path[file_path.length() - 1] != '/') {
file_path += '/';
}
if (conf != nullptr) {
file_path += conf;
}
this->_conf_file_path = file_path;
ChannelMap* channel_map = this->load_channel_map();
if (channel_map == nullptr) {
APP_LOG(ERROR) << "Failed to init RemoteServiceManager, cannot load channel map";
return -1;
}
this->_p_channel_map.reset(channel_map,
[](ChannelMap* p) { destroy_channel_map(p); });
FileWatcher::get_instance().register_file(
this->_conf_file_path, RemoteServiceManager::service_conf_change_callback, this, true);
return 0;
}
int RemoteServiceManager::reload() {
APP_LOG(TRACE) << "Reloading RemoteServiceManager...";
ChannelMap* channel_map = this->load_channel_map();
if (channel_map == nullptr) {
APP_LOG(ERROR) << "Failed to reload RemoteServiceManager, cannot load channel map";
return -1;
}
this->_p_channel_map.reset(channel_map,
[](ChannelMap* p) { destroy_channel_map(p); });
APP_LOG(TRACE) << "Reload finished.";
return 0;
}
int RemoteServiceManager::service_conf_change_callback(void* param) {
RemoteServiceManager* rsm = (RemoteServiceManager*)param;
return rsm->reload();
}
int RemoteServiceManager::call(const std::string& service_name,
const RemoteServiceParam& params,
RemoteServiceResult& result) const {
std::shared_ptr p_channel_map(this->_p_channel_map);
if (p_channel_map == nullptr) {
APP_LOG(ERROR) << "Remote service call failed, channel map is null";
return -1;
}
if (p_channel_map->find(service_name) == p_channel_map->end()) {
APP_LOG(ERROR) << "Remote service call failed, cannot find service " << service_name;
return -1;
}
RemoteServiceChannel& service_channel = (*p_channel_map)[service_name];
APP_LOG(TRACE) << "Calling service " << service_name;
int ret = 0;
std::string remote_side;
int latency = 0;
if (service_channel.protocol == "http") {
if (service_channel.channel != nullptr) {
ret = this->call_http_by_BRPC_NAMESPACE(service_channel.channel,
params.url,
params.http_method,
service_channel.headers,
params.payload,
result.result,
remote_side,
latency);
} else {
ret = this->call_http_by_curl(params.url,
params.http_method,
service_channel.headers,
params.payload,
service_channel.timeout_ms,
service_channel.max_retry,
result.result,
remote_side,
latency);
}
} else {
APP_LOG(ERROR) << "Remote service call failed. Unknown protocol" << service_channel.protocol;
ret = -1;
}
std::string log_str;
log_str += "remote:";
log_str += remote_side;
log_str += "|tm:";
log_str += std::to_string(latency);
log_str += "|ret:";
log_str += std::to_string(ret);
std::string log_key = "service_";
log_key += service_name;
APP_LOG(TRACE) << "remote_side=" << remote_side << ", cost=" << latency;
ThreadDataBase* tls = static_cast(BRPC_NAMESPACE::thread_local_data());
// All backend requests are logged.
tls->add_notice_log(log_key, log_str);
return ret;
}
// Parses the json configuration file and builds a channel map from it.
//
// The document root must be an object whose members are service names. Every
// service is an object with:
//   naming_service_url (string, required)
//   load_balancer_name (string, required)
//   protocol           (string, required; only "http" is supported)
//   client             (string, optional; "brpc" (default) or "curl")
//   timeout_ms         (int, required)
//   retry              (int, required)
//   headers            (object of string values, optional)
//
// Returns a newly allocated map (release it with destroy_channel_map), or
// nullptr when the file cannot be read or any service entry is invalid; in
// that case everything created so far is destroyed.
ChannelMap* RemoteServiceManager::load_channel_map() {
    APP_LOG(TRACE) << "Loading channel map...";
    FILE* fp = fopen(this->_conf_file_path.c_str(), "r");
    if (fp == nullptr) {
        APP_LOG(ERROR) << "Failed to open file " << this->_conf_file_path;
        return nullptr;
    }
    char read_buffer[1024];
    rapidjson::FileReadStream is(fp, read_buffer, sizeof(read_buffer));
    rapidjson::Document doc;
    doc.ParseStream(is);
    fclose(fp);
    if (doc.HasParseError() || !doc.IsObject()) {
        APP_LOG(ERROR) << "Failed to parse RemoteServiceManager settings";
        return nullptr;
    }

    // The guard destroys the partially built map on every early return;
    // ownership is handed to the caller with release() at the end.
    std::unique_ptr<ChannelMap, void (*)(ChannelMap*)> channel_map(
            new ChannelMap(), destroy_channel_map);

    rapidjson::Value::ConstMemberIterator service_iter;
    for (service_iter = doc.MemberBegin(); service_iter != doc.MemberEnd(); ++service_iter) {
        // Service name as the key for the channel.
        std::string service_name = service_iter->name.GetString();
        if (!service_iter->value.IsObject()) {
            APP_LOG(ERROR) << "Invalid service settings for " << service_name
                    << ", expecting type object for service setting.";
            return nullptr;
        }
        const rapidjson::Value& settings = service_iter->value;
        rapidjson::Value::ConstMemberIterator setting_iter;

        // Naming service url such as https://www.baidu.com.
        // All supported url format can be found in BRPC docs.
        setting_iter = settings.FindMember("naming_service_url");
        if (setting_iter == settings.MemberEnd() || !setting_iter->value.IsString()) {
            APP_LOG(ERROR) << "Invalid service settings for " << service_name
                    << ", expecting type String for property naming_service_url.";
            return nullptr;
        }
        std::string naming_service_url = setting_iter->value.GetString();

        // Load balancer name such as random, rr.
        // All supported balancer can be found in BRPC docs.
        setting_iter = settings.FindMember("load_balancer_name");
        if (setting_iter == settings.MemberEnd() || !setting_iter->value.IsString()) {
            APP_LOG(ERROR) << "Invalid service settings for " << service_name
                    << ", expecting type String for property load_balancer_name.";
            return nullptr;
        }
        std::string load_balancer_name = setting_iter->value.GetString();

        // Protocol for the channel.
        // Currently we support http.
        setting_iter = settings.FindMember("protocol");
        if (setting_iter == settings.MemberEnd() || !setting_iter->value.IsString()) {
            APP_LOG(ERROR) << "Invalid service settings for " << service_name
                    << ", expecting type String for property protocol.";
            return nullptr;
        }
        std::string protocol = setting_iter->value.GetString();

        // Client to use for sending request.
        setting_iter = settings.FindMember("client");
        std::string client;
        if (setting_iter != settings.MemberEnd() && setting_iter->value.IsString()) {
            client = setting_iter->value.GetString();
        }

        // Timeout value in millisecond.
        setting_iter = settings.FindMember("timeout_ms");
        if (setting_iter == settings.MemberEnd() || !setting_iter->value.IsInt()) {
            APP_LOG(ERROR) << "Invalid service settings for " << service_name
                    << ", expecting type Int for property timeout_ms.";
            return nullptr;
        }
        int timeout_ms = setting_iter->value.GetInt();

        // Retry count.
        setting_iter = settings.FindMember("retry");
        if (setting_iter == settings.MemberEnd() || !setting_iter->value.IsInt()) {
            APP_LOG(ERROR) << "Invalid service settings for " << service_name
                    << ", expecting type Int for property retry.";
            return nullptr;
        }
        int retry = setting_iter->value.GetInt();

        // Headers for a http request.
        std::vector<std::pair<std::string, std::string>> headers;
        setting_iter = settings.FindMember("headers");
        if (setting_iter != settings.MemberEnd() && setting_iter->value.IsObject()) {
            const rapidjson::Value& obj_headers = setting_iter->value;
            rapidjson::Value::ConstMemberIterator header_iter;
            for (header_iter = obj_headers.MemberBegin();
                    header_iter != obj_headers.MemberEnd(); ++header_iter) {
                std::string header_key = header_iter->name.GetString();
                if (!header_iter->value.IsString()) {
                    APP_LOG(ERROR) << "Invalid header value for " << header_key
                            << ", expecting type String for header value.";
                    return nullptr;
                }
                std::string header_value = header_iter->value.GetString();
                headers.push_back(std::make_pair(header_key, header_value));
            }
        }

        BRPC_NAMESPACE::Channel* rpc_channel = nullptr;
        if (protocol == "http") {
            if (client.empty() || client == "brpc") {
                rpc_channel = new BRPC_NAMESPACE::Channel();
                BRPC_NAMESPACE::ChannelOptions options;
                options.protocol = BRPC_NAMESPACE::PROTOCOL_HTTP;
                options.timeout_ms = timeout_ms;
                options.max_retry = retry;
                int ret = rpc_channel->Init(
                        naming_service_url.c_str(), load_balancer_name.c_str(), &options);
                if (ret != 0) {
                    APP_LOG(ERROR) << "Failed to init channel.";
                    delete rpc_channel;
                    return nullptr;
                }
            } else if (client == "curl") {
                // curl does not need to init rpc channel
            } else {
                APP_LOG(ERROR) << "Unsupported client value [" << client << "].";
                return nullptr;
            }
        } else {
            APP_LOG(ERROR) << "Unsupported protocol [" << protocol
                    << "] for service [" << service_name << "], skipped...";
            return nullptr;
        }

        RemoteServiceChannel service_channel;
        service_channel.name = service_name;
        service_channel.protocol = protocol;
        service_channel.channel = rpc_channel;
        service_channel.timeout_ms = timeout_ms;
        service_channel.max_retry = retry;
        service_channel.headers = headers;

        // insert() keeps the first entry when a service name is repeated in
        // the json object; free the channel that would otherwise be orphaned.
        if (!channel_map->insert({service_name, service_channel}).second) {
            APP_LOG(WARNING) << "Duplicate service name [" << service_name
                    << "], keeping the first definition";
            delete rpc_channel;
            continue;
        }
        APP_LOG(TRACE) << "Loaded service " << service_name;
    }
    return channel_map.release();
}
// Sends one http request through the given brpc channel.
// The request is a POST carrying `payload` when method is HTTP_METHOD_POST;
// for every other method value the controller's method is left untouched.
// A "Content-Type"/"content-type" header sets the request content type, all
// other headers are added as-is.
// remote_side and latency (milliseconds) are filled whether or not the call
// succeeds; result is only written on success.
// Returns 0 on success, -1 when the call failed.
int RemoteServiceManager::call_http_by_BRPC_NAMESPACE(BRPC_NAMESPACE::Channel* channel,
        const std::string& url,
        const HttpMethod method,
        const std::vector>& headers,
        const std::string& payload,
        std::string& result,
        std::string& remote_side,
        int& latency) const {
    BRPC_NAMESPACE::Controller cntl;
    cntl.http_request().uri() = url.c_str();
    if (method == HTTP_METHOD_POST) {
        cntl.http_request().set_method(BRPC_NAMESPACE::HTTP_METHOD_POST);
        cntl.request_attachment().append(payload);
    }
    for (auto it = headers.begin(); it != headers.end(); ++it) {
        const bool is_content_type =
                (it->first == "Content-Type" || it->first == "content-type");
        if (is_content_type) {
            cntl.http_request().set_content_type(it->second);
        } else {
            cntl.http_request().SetHeader(it->first, it->second);
        }
    }

    channel->CallMethod(NULL, &cntl, NULL, NULL, NULL);

    const bool call_failed = cntl.Failed();
    if (call_failed) {
        APP_LOG(WARNING) << "Call failed, error: " << cntl.ErrorText();
    } else {
        result = cntl.response_attachment().to_string();
    }
    remote_side = BUTIL_NAMESPACE::endpoint2str(cntl.remote_side()).c_str();
    latency = cntl.latency_us() / 1000;
    return call_failed ? -1 : 0;
}
static size_t curl_write_callback(void *contents, size_t size, size_t nmemb, void *userp) {
BUTIL_NAMESPACE::IOBuf* buffer = static_cast(userp);
size_t realsize = size * nmemb;
buffer->append(contents, realsize);
return realsize;
}
int RemoteServiceManager::call_http_by_curl(const std::string& url,
const HttpMethod method,
const std::vector>& headers,
const std::string& payload,
const int timeout_ms,
const int max_retry,
std::string& result,
std::string& remote_side,
int& latency) const {
// curl does not support retry now
(void)max_retry;
CURL *curl;
CURLcode res;
struct curl_slist *curl_headers = nullptr;
curl = curl_easy_init();
if (!curl) {
APP_LOG(ERROR) << "Failed to init curl";
return -1;
}
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
BUTIL_NAMESPACE::IOBuf response_buffer;
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, curl_write_callback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, static_cast(&response_buffer));
if (method == HTTP_METHOD_POST) {
curl_easy_setopt(curl, CURLOPT_POST, 1L);
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, payload.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, payload.length());
}
for (auto const& header: headers) {
std::string header_value = header.first;
header_value += ": ";
header_value += header.second;
curl_headers = curl_slist_append(curl_headers, header_value.c_str());
}
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers);
curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, timeout_ms);
//curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L);
res = curl_easy_perform(curl);
curl_slist_free_all(curl_headers);
if(res != CURLE_OK) {
APP_LOG(ERROR) << "curl failed, error: " << curl_easy_strerror(res);
curl_easy_cleanup(curl);
return -1;
}
double total_time;
res = curl_easy_getinfo(curl, CURLINFO_TOTAL_TIME, &total_time);
if (CURLE_OK == res) {
latency = total_time * 1000;
}
char *ip = nullptr;
res = curl_easy_getinfo(curl, CURLINFO_PRIMARY_IP, &ip);
if (CURLE_OK == res && ip != nullptr) {
remote_side = ip;
}
curl_easy_cleanup(curl);
result = response_buffer.to_string();
return 0;
}
} // namespace dmkit
================================================
FILE: src/remote_service_manager.h
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DMKIT_REMOTE_SERVICE_MANAGER_H
#define DMKIT_REMOTE_SERVICE_MANAGER_H
#include
#include
#include
#include
#include
#include "brpc.h"
#include "butil.h"
namespace dmkit {
// Http methods a caller can request for a remote service call.
enum HttpMethod {
// Set to the same number as the method defined in baidu::rpc::HttpMethod
// (value 2 is not used by this enum).
HTTP_METHOD_DELETE = 0,
HTTP_METHOD_GET = 1,
HTTP_METHOD_POST = 3,
HTTP_METHOD_PUT = 4
};
// Per-call request parameters passed to RemoteServiceManager::call.
struct RemoteServiceParam {
// Url for http service
std::string url;
// Http method
HttpMethod http_method;
// Request body
std::string payload;
};
// Output of RemoteServiceManager::call.
struct RemoteServiceResult {
// Response body
std::string result;
};
// Settings describing one configured remote service channel.
struct RemoteServiceChannel {
// Name of the channel for rpc call
std::string name;
// Protocol such as http
std::string protocol;
// Rpc channel instance.
// NOTE(review): raw pointer; who creates and releases it is not visible in
// this header - confirm against the implementation file.
BRPC_NAMESPACE::Channel *channel;
// Timeout in milliseconds
int timeout_ms;
// Retry count
int max_retry;
// Headers for http protocol
std::vector> headers;
};
// Type for the channel map held by RemoteServiceManager.
// NOTE(review): presumably keyed by service/channel name with a
// RemoteServiceChannel as the mapped value - confirm against the loader.
typedef std::unordered_map ChannelMap;
// A configurable remote service manager class.
// All remote service channels are created from a configuration file
// during initialization. A caller invokes a remote service by supplying
// the service name and the other call parameters.
class RemoteServiceManager {
public:
RemoteServiceManager();
~RemoteServiceManager();
// Initialization with a json configuration file.
// NOTE(review): path/conf presumably are the directory and the file name of
// the configuration - confirm against the implementation.
int init(const char *path, const char *conf);
// Reload config.
int reload();
// Callback when service conf changed.
// NOTE(review): param is presumably the manager instance to reload - confirm.
static int service_conf_change_callback(void* param);
// Call a remote service with the specified service name.
// The response is written to result.
int call(const std::string& servie_name,
const RemoteServiceParam& params,
RemoteServiceResult &result) const;
private:
// Http is the most common protocol.
// Performs an http call through the given rpc channel; result, remote_side
// and latency are output parameters.
int call_http_by_BRPC_NAMESPACE(BRPC_NAMESPACE::Channel* channel,
const std::string& url,
const HttpMethod method,
const std::vector>& headers,
const std::string& payload,
std::string& result,
std::string& remote_side,
int& latency) const;
// Performs an http call with curl; result, remote_side and latency are
// output parameters.
int call_http_by_curl(const std::string& url,
const HttpMethod method,
const std::vector>& headers,
const std::string& payload,
const int timeout_ms,
const int max_retry,
std::string& result,
std::string& remote_side,
int& latency) const;
// Builds a channel map.
// NOTE(review): presumably read from _conf_file_path, with ownership of the
// returned pointer passing to the caller - confirm.
ChannelMap* load_channel_map();
// Path of the configuration file.
std::string _conf_file_path;
// Shared handle to the channel map currently in use.
std::shared_ptr _p_channel_map;
};
} // namespace dmkit
#endif // DMKIT_REMOTE_SERVICE_MANAGER_H
================================================
FILE: src/request_context.cpp
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "request_context.h"
namespace dmkit {
// Creates a request context: stores the remote service manager pointer and
// keeps its own copies of the query id and of the request parameters.
RequestContext::RequestContext(RemoteServiceManager* remote_service_manager,
const std::string& qid,
const std::unordered_map& params)
: _remote_service_manager(remote_service_manager), _qid(qid), _params(params) {
}
// Nothing to do here: the destructor body is empty, so the remote service
// manager pointer is left untouched.
RequestContext::~RequestContext() {}
// Read-only access to the remote service manager given at construction.
const RemoteServiceManager* RequestContext::remote_service_manager() const {
    return this->_remote_service_manager;
}
const std::string& RequestContext::qid() const {
return _qid;
}
// Read-only view of all request parameters stored in this context.
const std::unordered_map& RequestContext::params() const {
return _params;
}
bool RequestContext::set_param_value(const std::string& param_name, const std::string& value) {
this->_params[param_name] = value;
return true;
}
bool RequestContext::try_get_param(const std::string& param_name, std::string& value) const {
value.clear();
auto search = this->_params.find(param_name);
if (search == this->_params.end()) {
return false;
}
value = search->second;
return true;
}
} // namespace dmkit
================================================
FILE: src/request_context.h
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DMKIT_REQUEST_CONTEXT_H
#define DMKIT_REQUEST_CONTEXT_H
#include
#include
#include "remote_service_manager.h"
namespace dmkit {
// Request context for user functions to access,
// including request parameters and a remote_service_manager instance.
class RequestContext {
public:
// Stores the manager pointer and copies qid and params into the context.
RequestContext(RemoteServiceManager* remote_service_manager,
const std::string& qid,
const std::unordered_map& params);
~RequestContext();
// Remote service manager supplied at construction (read-only access).
const RemoteServiceManager* remote_service_manager() const;
// Query id of the request.
const std::string& qid() const;
// All request parameters.
const std::unordered_map& params() const;
// Sets (adds or overwrites) a parameter value; returns true.
bool set_param_value(const std::string& param_name, const std::string& value);
// Copies the value of param_name into value and returns true when the
// parameter exists; otherwise value is left empty and false is returned.
bool try_get_param(const std::string& param_name, std::string& value) const;
private:
// Pointer to the remote service manager; the destructor does not release it.
RemoteServiceManager* _remote_service_manager;
// Query id.
std::string _qid;
// Request parameters.
std::unordered_map _params;
};
} // namespace dmkit
#endif // DMKIT_REQUEST_CONTEXT_H
================================================
FILE: src/server.cpp
================================================
// Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include
#include "app_container.h"
#include "app_log.h"
#include "brpc.h"
#include "http.pb.h"
// Command line / flag file options of the server process.
// Listening port passed to Server::Start(); defaults to -1.
DEFINE_int32(port, -1, "TCP Port of this server");
// Copied into ServerOptions::internal_port; defaults to -1.
DEFINE_int32(internal_port, -1, "Only allow builtin services at this port");
// Copied into ServerOptions::idle_timeout_sec; defaults to -1.
DEFINE_int32(idle_timeout_s, -1, "Connection will be closed if there is no read/write operations during this time");
// Copied into ServerOptions::max_concurrency; defaults to 0.
DEFINE_int32(max_concurrency, 0, "Limit of requests processing in parallel");
// Url path that is mapped to HttpServiceImpl::run; empty by default.
DEFINE_string(url_path, "", "Url path of the app");
// Enables the comlog sink setup when BUTIL_ENABLE_COMLOG_SINK is defined.
DEFINE_bool(log_to_file, false, "Log to file");
namespace dmkit {
// Service with static path.
// Forwards every incoming request to the AppContainer it owns.
class HttpServiceImpl : public HttpService {
public:
HttpServiceImpl() {};
virtual ~HttpServiceImpl() noexcept {};
// Loads the application through the container and returns its status code
// (main() treats any non-zero value as a failure).
int init() {
return this->_app_container.load_application();
}
// Exposes the container's thread local data factory so that it can be
// handed to the server options.
ThreadLocalDataFactory* get_thread_local_data_factory() {
return this->_app_container.get_thread_local_data_factory();
}
// Rpc entry point. The request and response messages are not used here;
// the controller is passed on to the application container.
void run(google::protobuf::RpcController* cntl_base,
const HttpRequest*,
HttpResponse*,
google::protobuf::Closure* done) {
// This object helps you to call done->Run() in RAII style. If you need
// to process the request asynchronously, pass done_guard.release().
BRPC_NAMESPACE::ClosureGuard done_guard(done);
BRPC_NAMESPACE::Controller* cntl = static_cast(cntl_base);
this->_app_container.run(cntl);
}
private:
// Container holding the loaded application.
AppContainer _app_container;
};
} // namespace dmkit
int main(int argc, char* argv[]) {
// Parse gflags. We recommend you to use gflags as well.
GFLAGS_NS::SetCommandLineOption("flagfile", "conf/gflags.conf");
GFLAGS_NS::ParseCommandLineFlags(&argc, &argv, true);
#ifdef BUTIL_ENABLE_COMLOG_SINK
if (FLAGS_log_to_file) {
if (logging::ComlogSink::GetInstance()->SetupFromConfig("conf/log.conf") != 0) {
APP_LOG(ERROR) << "Fail to setup comlog from conf/log.conf";
return -1;
}
}
#endif
// Generally you only need one Server.
BRPC_NAMESPACE::Server server;
dmkit::HttpServiceImpl http_svc;
if (http_svc.init() != 0) {
APP_LOG(ERROR) << "Fail to init http_svc";
return -1;
}
// Add services into server. Notice the second parameter, because the
// service is put on stack, we don't want server to delete it, otherwise
// use baidu::rpc::SERVER_OWNS_SERVICE.
std::string mapping = FLAGS_url_path + " => run";
if (server.AddService(&http_svc,
BRPC_NAMESPACE::SERVER_DOESNT_OWN_SERVICE,
mapping.c_str()) != 0) {
APP_LOG(ERROR) << "Fail to add http_svc";
return -1;
}
// Start the server.
BRPC_NAMESPACE::ServerOptions options;
options.idle_timeout_sec = FLAGS_idle_timeout_s;
options.max_concurrency = FLAGS_max_concurrency;
options.internal_port = FLAGS_internal_port;
options.thread_local_data_factory = http_svc.get_thread_local_data_factory();
APP_LOG(TRACE) << "Starting server...";
if (server.Start(FLAGS_port, &options) != 0) {
APP_LOG(ERROR) << "Fail to start server";
return -1;
}
// Wait until Ctrl-C is pressed, then Stop() and Join() the server.
server.RunUntilAskedToQuit();
return 0;
}
================================================
FILE: src/thirdparty/rapidjson/allocators.h
================================================
// Tencent is pleased to support the open source community by making RapidJSON available.
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. All rights reserved.
//
// Licensed under the MIT License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_ALLOCATORS_H_
#define RAPIDJSON_ALLOCATORS_H_
#include "rapidjson.h"
RAPIDJSON_NAMESPACE_BEGIN
///////////////////////////////////////////////////////////////////////////////
// Allocator
/*! \class rapidjson::Allocator
\brief Concept for allocating, resizing and freeing memory block.
Note that Malloc() and Realloc() are non-static but Free() is static.
So if an allocator need to support Free(), it needs to put its pointer in
the header of memory block.
\code
concept Allocator {
static const bool kNeedFree; //!< Whether this allocator needs to call Free().
// Allocate a memory block.
// \param size of the memory block in bytes.
// \returns pointer to the memory block.
void* Malloc(size_t size);
// Resize a memory block.
// \param originalPtr The pointer to current memory block. Null pointer is permitted.
// \param originalSize The current size in bytes. (Design issue: since some allocator may not book-keep this, explicitly pass to it can save memory.)
// \param newSize the new size in bytes.
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize);
// Free a memory block.
// \param pointer to the memory block. Null pointer is permitted.
static void Free(void *ptr);
};
\endcode
*/
///////////////////////////////////////////////////////////////////////////////
// CrtAllocator
//! Allocator backed by the C runtime library.
/*! A thin wrapper around std::malloc, std::realloc and std::free.
    \note implements Allocator concept
*/
class CrtAllocator {
public:
    //! Memory obtained from this allocator has to be released with Free().
    static const bool kNeedFree = true;

    //! Allocates \c size bytes; a request for zero bytes yields NULL.
    void* Malloc(size_t size) {
        // malloc(0) is implementation defined, so it is standardized to NULL.
        return size ? std::malloc(size) : NULL;
    }

    //! Resizes a block; resizing to zero bytes frees it and yields NULL.
    /*! \c originalSize is not needed by the C runtime and is ignored. */
    void* Realloc(void* originalPtr, size_t originalSize, size_t newSize) {
        (void)originalSize;
        if (newSize != 0)
            return std::realloc(originalPtr, newSize);
        std::free(originalPtr);
        return NULL;
    }

    //! Releases a block; a null pointer is passed straight to std::free.
    static void Free(void *ptr) {
        std::free(ptr);
    }
};
///////////////////////////////////////////////////////////////////////////////
// MemoryPoolAllocator
//! Default memory allocator used by the parser and DOM.
/*! This allocator allocates memory blocks from pre-allocated memory chunks.
It does not free individual memory blocks, and Realloc() only allocates new
memory (or grows the most recent allocation in place).
The memory chunks are allocated by BaseAllocator, which is CrtAllocator by default.
User may also supply a buffer as the first chunk.
If the user-buffer is full then additional chunks are allocated by BaseAllocator.
The user-buffer is not deallocated by this allocator.
\tparam BaseAllocator the allocator type for allocating memory chunks. Default is CrtAllocator.
\note implements Allocator concept
*/
template
class MemoryPoolAllocator {
public:
static const bool kNeedFree = false; //!< Tell users that no need to call Free() with this allocator. (concept Allocator)
//! Constructor with chunkSize.
/*! \param chunkSize The size of memory chunk. The default is kDefaultChunkCapacity.
\param baseAllocator The allocator for allocating memory chunks.
If it is null, a BaseAllocator is created lazily when the first chunk is added.
*/
MemoryPoolAllocator(size_t chunkSize = kDefaultChunkCapacity, BaseAllocator* baseAllocator = 0) :
chunkHead_(0), chunk_capacity_(chunkSize), userBuffer_(0), baseAllocator_(baseAllocator), ownBaseAllocator_(0)
{
}
//! Constructor with user-supplied buffer.
/*! The user buffer will be used first. When it is full, the memory pool allocates a new chunk with chunk size.
The user buffer will not be deallocated when this allocator is destructed.
\param buffer User supplied buffer.
\param size Size of the buffer in bytes. It must be larger than sizeof(ChunkHeader).
\param chunkSize The size of memory chunk. The default is kDefaultChunkCapacity.
\param baseAllocator The allocator for allocating memory chunks.
*/
MemoryPoolAllocator(void *buffer, size_t size, size_t chunkSize = kDefaultChunkCapacity, BaseAllocator* baseAllocator = 0) :
chunkHead_(0), chunk_capacity_(chunkSize), userBuffer_(buffer), baseAllocator_(baseAllocator), ownBaseAllocator_(0)
{
RAPIDJSON_ASSERT(buffer != 0);
RAPIDJSON_ASSERT(size > sizeof(ChunkHeader));
// The user buffer becomes the first chunk; its header is placed at the
// start of the buffer and the remainder is the usable capacity.
chunkHead_ = reinterpret_cast(buffer);
chunkHead_->capacity = size - sizeof(ChunkHeader);
chunkHead_->size = 0;
chunkHead_->next = 0;
}
//! Destructor.
/*! This deallocates all memory chunks, excluding the user-supplied buffer,
and deletes the base allocator if it was created by this object.
*/
~MemoryPoolAllocator() {
Clear();
RAPIDJSON_DELETE(ownBaseAllocator_);
}
//! Deallocates all memory chunks, excluding the user-supplied buffer.
void Clear() {
// Chunks are released from the head of the list until the list is
// exhausted or the user buffer is reached.
while (chunkHead_ && chunkHead_ != userBuffer_) {
ChunkHeader* next = chunkHead_->next;
baseAllocator_->Free(chunkHead_);
chunkHead_ = next;
}
// The user buffer is kept but marked as empty so it can be reused.
if (chunkHead_ && chunkHead_ == userBuffer_)
chunkHead_->size = 0; // Clear user buffer
}
//! Computes the total capacity of allocated memory chunks.
/*! \return total capacity in bytes.
*/
size_t Capacity() const {
size_t capacity = 0;
for (ChunkHeader* c = chunkHead_; c != 0; c = c->next)
capacity += c->capacity;
return capacity;
}
//! Computes the memory blocks allocated.
/*! \return total used bytes.
*/
size_t Size() const {
size_t size = 0;
for (ChunkHeader* c = chunkHead_; c != 0; c = c->next)
size += c->size;
return size;
}
//! Allocates a memory block. (concept Allocator)
/*! \param size Requested size in bytes; it is rounded up with RAPIDJSON_ALIGN.
\return pointer to the block, or NULL when size is 0 or a new chunk cannot be added.
*/
void* Malloc(size_t size) {
if (!size)
return NULL;
size = RAPIDJSON_ALIGN(size);
// Only the head chunk serves allocations: when it cannot hold the
// request a new head chunk is added, sized to the larger of the
// configured chunk capacity and the request itself.
if (chunkHead_ == 0 || chunkHead_->size + size > chunkHead_->capacity)
if (!AddChunk(chunk_capacity_ > size ? chunk_capacity_ : size))
return NULL;
void *buffer = reinterpret_cast(chunkHead_) + RAPIDJSON_ALIGN(sizeof(ChunkHeader)) + chunkHead_->size;
chunkHead_->size += size;
return buffer;
}
//! Resizes a memory block (concept Allocator)
/*! \param originalPtr Block to resize; a null pointer makes this behave like Malloc(newSize).
\param originalSize Current size of the block in bytes.
\param newSize Requested size in bytes.
\return pointer to the resized block, or NULL when newSize is 0 or memory cannot be obtained.
The original block is never released.
*/
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize) {
if (originalPtr == 0)
return Malloc(newSize);
if (newSize == 0)
return NULL;
originalSize = RAPIDJSON_ALIGN(originalSize);
newSize = RAPIDJSON_ALIGN(newSize);
// Do not shrink if new size is smaller than original
if (originalSize >= newSize)
return originalPtr;
// Simply expand it if it is the last allocation and there is sufficient space
if (originalPtr == reinterpret_cast(chunkHead_) + RAPIDJSON_ALIGN(sizeof(ChunkHeader)) + chunkHead_->size - originalSize) {
size_t increment = static_cast(newSize - originalSize);
if (chunkHead_->size + increment <= chunkHead_->capacity) {
chunkHead_->size += increment;
return originalPtr;
}
}
// Realloc process: allocate and copy memory, do not free original buffer.
if (void* newBuffer = Malloc(newSize)) {
if (originalSize)
std::memcpy(newBuffer, originalPtr, originalSize);
return newBuffer;
}
else
return NULL;
}
//! Frees a memory block (concept Allocator)
static void Free(void *ptr) { (void)ptr; } // Do nothing
private:
//! Copy constructor is not permitted.
MemoryPoolAllocator(const MemoryPoolAllocator& rhs) /* = delete */;
//! Copy assignment operator is not permitted.
MemoryPoolAllocator& operator=(const MemoryPoolAllocator& rhs) /* = delete */;
//! Creates a new chunk.
/*! The new chunk becomes the head of the chunk list.
\param capacity Capacity of the chunk in bytes.
\return true if success.
*/
bool AddChunk(size_t capacity) {
// Lazily create (and remember ownership of) a base allocator when the
// user did not supply one.
if (!baseAllocator_)
ownBaseAllocator_ = baseAllocator_ = RAPIDJSON_NEW(BaseAllocator());
if (ChunkHeader* chunk = reinterpret_cast(baseAllocator_->Malloc(RAPIDJSON_ALIGN(sizeof(ChunkHeader)) + capacity))) {
chunk->capacity = capacity;
chunk->size = 0;
chunk->next = chunkHead_;
chunkHead_ = chunk;
return true;
}
else
return false;
}
static const int kDefaultChunkCapacity = 64 * 1024; //!< Default chunk capacity.
//! Chunk header prepended to each chunk.
/*! Chunks are stored as a singly linked list.
*/
struct ChunkHeader {
size_t capacity; //!< Capacity of the chunk in bytes (excluding the header itself).
size_t size; //!< Current size of allocated memory in bytes.
ChunkHeader *next; //!< Next chunk in the linked list.
};
ChunkHeader *chunkHead_; //!< Head of the chunk linked-list. Only the head chunk serves allocation.
size_t chunk_capacity_; //!< The minimum capacity of chunk when they are allocated.
void *userBuffer_; //!< User supplied buffer.
BaseAllocator* baseAllocator_; //!< base allocator for allocating memory chunks.
BaseAllocator* ownBaseAllocator_; //!< base allocator created by this object.
};
RAPIDJSON_NAMESPACE_END
#endif // RAPIDJSON_ALLOCATORS_H_
================================================
FILE: src/thirdparty/rapidjson/document.h
================================================
// Tencent is pleased to support the open source community by making RapidJSON available.
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. All rights reserved.
//
// Licensed under the MIT License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_DOCUMENT_H_
#define RAPIDJSON_DOCUMENT_H_
/*! \file document.h */
#include "reader.h"
#include "internal/meta.h"
#include "internal/strfunc.h"
#include "memorystream.h"
#include "encodedstream.h"
#include // placement new
#include
RAPIDJSON_DIAG_PUSH
#ifdef _MSC_VER
RAPIDJSON_DIAG_OFF(4127) // conditional expression is constant
RAPIDJSON_DIAG_OFF(4244) // conversion from kXxxFlags to 'uint16_t', possible loss of data
#endif
#ifdef __clang__
RAPIDJSON_DIAG_OFF(padded)
RAPIDJSON_DIAG_OFF(switch-enum)
RAPIDJSON_DIAG_OFF(c++98-compat)
#endif
#ifdef __GNUC__
RAPIDJSON_DIAG_OFF(effc++)
#if __GNUC__ >= 6
RAPIDJSON_DIAG_OFF(terminate) // ignore throwing RAPIDJSON_ASSERT in RAPIDJSON_NOEXCEPT functions
#endif
#endif // __GNUC__
#ifndef RAPIDJSON_NOMEMBERITERATORCLASS
#include // std::iterator, std::random_access_iterator_tag
#endif
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
#include // std::move
#endif
RAPIDJSON_NAMESPACE_BEGIN
// Forward declarations of the class templates referenced by the helper
// types defined below.
template
class GenericValue;
template
class GenericDocument;
//! Name-value pair in a JSON object value.
/*!
This class was internal to GenericValue. It used to be an inner struct.
But a compiler (IBM XL C/C++ for AIX) has been reported to have a problem with that, so it was moved to a namespace scope struct.
https://code.google.com/p/rapidjson/issues/detail?id=64
*/
template
struct GenericMember {
GenericValue name; //!< name of member (must be a string)
GenericValue value; //!< value of member.
};
///////////////////////////////////////////////////////////////////////////////
// GenericMemberIterator
#ifndef RAPIDJSON_NOMEMBERITERATORCLASS
//! (Constant) member iterator for a JSON object value
/*!
\tparam Const Is this a constant iterator?
\tparam Encoding Encoding of the value. (Even non-string values need to have the same encoding in a document)
\tparam Allocator Allocator type for allocating memory of object, array and string.
This class implements a Random Access Iterator for GenericMember elements
of a GenericValue, see ISO/IEC 14882:2003(E) C++ standard, 24.1 [lib.iterator.requirements].
The iterator is a thin wrapper around a raw pointer to the member.
\note This iterator implementation is mainly intended to avoid implicit
conversions from iterator values to \c NULL,
e.g. from GenericValue::FindMember.
\note Define \c RAPIDJSON_NOMEMBERITERATORCLASS to fall back to a
pointer-based implementation, if your platform doesn't provide
the C++ <iterator> header.
\see GenericMember, GenericValue::MemberIterator, GenericValue::ConstMemberIterator
*/
template
class GenericMemberIterator
: public std::iterator >::Type> {
friend class GenericValue;
template friend class GenericMemberIterator;
typedef GenericMember PlainType;
typedef typename internal::MaybeAddConst::Type ValueType;
typedef std::iterator BaseType;
public:
//! Iterator type itself
typedef GenericMemberIterator Iterator;
//! Constant iterator type
typedef GenericMemberIterator ConstIterator;
//! Non-constant iterator type
typedef GenericMemberIterator NonConstIterator;
//! Pointer to (const) GenericMember
typedef typename BaseType::pointer Pointer;
//! Reference to (const) GenericMember
typedef typename BaseType::reference Reference;
//! Signed integer type (e.g. \c ptrdiff_t)
typedef typename BaseType::difference_type DifferenceType;
//! Default constructor (singular value)
/*! Creates an iterator pointing to no element.
\note All operations, except for comparisons, are undefined on such values.
*/
GenericMemberIterator() : ptr_() {}
//! Iterator conversions to more const
/*!
\param it (Non-const) iterator to copy from
Allows the creation of an iterator from another GenericMemberIterator
that is "less const". In particular, creating a non-constant iterator
from a constant iterator is disabled:
\li const -> non-const (not ok)
\li const -> const (ok)
\li non-const -> const (ok)
\li non-const -> non-const (ok)
\note If the \c Const template parameter is already \c false, this
constructor effectively defines a regular copy-constructor.
Otherwise, the copy constructor is implicitly defined.
*/
GenericMemberIterator(const NonConstIterator & it) : ptr_(it.ptr_) {}
//! Assignment from a non-constant iterator (same conversion rules as the constructor above).
Iterator& operator=(const NonConstIterator & it) { ptr_ = it.ptr_; return *this; }
//! @name stepping
//@{
// Pre-increment/decrement return the iterator itself; the postfix forms
// return a copy of the iterator taken before the step.
Iterator& operator++(){ ++ptr_; return *this; }
Iterator& operator--(){ --ptr_; return *this; }
Iterator operator++(int){ Iterator old(*this); ++ptr_; return old; }
Iterator operator--(int){ Iterator old(*this); --ptr_; return old; }
//@}
//! @name increment/decrement
//@{
// Offsetting by n members: the binary forms build a new iterator, the
// compound forms modify this iterator in place.
Iterator operator+(DifferenceType n) const { return Iterator(ptr_+n); }
Iterator operator-(DifferenceType n) const { return Iterator(ptr_-n); }
Iterator& operator+=(DifferenceType n) { ptr_+=n; return *this; }
Iterator& operator-=(DifferenceType n) { ptr_-=n; return *this; }
//@}
//! @name relations
//@{
// All comparisons compare the wrapped raw pointers; the argument is taken
// as a constant iterator.
bool operator==(ConstIterator that) const { return ptr_ == that.ptr_; }
bool operator!=(ConstIterator that) const { return ptr_ != that.ptr_; }
bool operator<=(ConstIterator that) const { return ptr_ <= that.ptr_; }
bool operator>=(ConstIterator that) const { return ptr_ >= that.ptr_; }
bool operator< (ConstIterator that) const { return ptr_ < that.ptr_; }
bool operator> (ConstIterator that) const { return ptr_ > that.ptr_; }
//@}
//! @name dereference
//@{
Reference operator*() const { return *ptr_; }
Pointer operator->() const { return ptr_; }
Reference operator[](DifferenceType n) const { return ptr_[n]; }
//@}
//! Distance (number of members between \c that and this iterator)
DifferenceType operator-(ConstIterator that) const { return ptr_-that.ptr_; }
private:
//! Internal constructor from plain pointer
explicit GenericMemberIterator(Pointer p) : ptr_(p) {}
Pointer ptr_; //!< raw pointer
};
#else // RAPIDJSON_NOMEMBERITERATORCLASS
// class-based member iterator implementation disabled, use plain pointers
// In this configuration GenericMemberIterator is only a holder for the
// nested Iterator typedef, which is a raw pointer to GenericMember.
template
struct GenericMemberIterator;
//! non-const GenericMemberIterator
template
struct GenericMemberIterator {
//! use plain pointer as iterator type
typedef GenericMember* Iterator;
};
//! const GenericMemberIterator
template
struct GenericMemberIterator {
//! use plain const pointer as iterator type
typedef const GenericMember* Iterator;
};
#endif // RAPIDJSON_NOMEMBERITERATORCLASS
///////////////////////////////////////////////////////////////////////////////
// GenericStringRef
//! Reference to a constant string (not taking a copy)
/*!
\tparam CharType character type of the string
This helper class is used to automatically infer constant string
references for string literals, especially from \c const \b (!)
character arrays.
The main use is for creating JSON string values without copying the
source string via an \ref Allocator. This requires that the referenced
string pointers have a sufficient lifetime, which exceeds the lifetime
of the associated GenericValue.
\b Example
\code
Value v("foo"); // ok, no need to copy & calculate length
const char foo[] = "foo";
v.SetString(foo); // ok
const char* bar = foo;
// Value x(bar); // not ok, can't rely on bar's lifetime
Value x(StringRef(bar)); // lifetime explicitly guaranteed by user
Value y(StringRef(bar, 3)); // ok, explicitly pass length
\endcode
\see StringRef, GenericValue::SetString
*/
template
struct GenericStringRef {
typedef CharType Ch; //!< character type of the string
//! Create string reference from \c const character array
#ifndef __clang__ // -Wdocumentation
/*!
This constructor implicitly creates a constant string reference from
a \c const character array. It has better performance than
\ref StringRef(const CharType*) by inferring the string \ref length
from the array length, and also supports strings containing null
characters.
\tparam N length of the string, automatically inferred
\param str Constant character array, lifetime assumed to be longer
than the use of the string in e.g. a GenericValue
\post \ref s == str
\note Constant complexity.
\note There is a hidden, private overload to disallow references to
non-const character arrays to be created via this constructor.
By this, e.g. function-scope arrays used to be filled via
\c snprintf are excluded from consideration.
In such cases, the referenced string should be \b copied to the
GenericValue instead.
*/
#endif
template
GenericStringRef(const CharType (&str)[N]) RAPIDJSON_NOEXCEPT
: s(str), length(N-1) {}
//! Explicitly create string reference from \c const character pointer
#ifndef __clang__ // -Wdocumentation
/*!
This constructor can be used to \b explicitly create a reference to
a constant string pointer.
\see StringRef(const CharType*)
\param str Constant character pointer, lifetime assumed to be longer
than the use of the string in e.g. a GenericValue
\post \ref s == str
\note There is a hidden, private overload to disallow references to
non-const character arrays to be created via this constructor.
By this, e.g. function-scope arrays used to be filled via
\c snprintf are excluded from consideration.
In such cases, the referenced string should be \b copied to the
GenericValue instead.
*/
#endif
explicit GenericStringRef(const CharType* str)
: s(str), length(internal::StrLen(str)){ RAPIDJSON_ASSERT(s != 0); }
//! Create constant string reference from pointer and length
#ifndef __clang__ // -Wdocumentation
/*! \param str constant string, lifetime assumed to be longer than the use of the string in e.g. a GenericValue
\param len length of the string, excluding the trailing NULL terminator
\post \ref s == str && \ref length == len
\note Constant complexity.
*/
#endif
GenericStringRef(const CharType* str, SizeType len)
: s(str), length(len) { RAPIDJSON_ASSERT(s != 0); }
GenericStringRef(const GenericStringRef& rhs) : s(rhs.s), length(rhs.length) {}
GenericStringRef& operator=(const GenericStringRef& rhs) { s = rhs.s; length = rhs.length; }
//! implicit conversion to plain CharType pointer
operator const Ch *() const { return s; }
const Ch* const s; //!< plain CharType pointer
const SizeType length; //!< length of the string (excluding the trailing NULL terminator)
private:
//! Disallow construction from non-const array
template
GenericStringRef(CharType (&str)[N]) /* = delete */;
};
//! Mark a character pointer as constant string
/*! Mark a plain character pointer as a "string literal". This function
can be used to avoid copying a character string to be referenced as a
value in a JSON GenericValue object, if the string's lifetime is known
to be valid long enough.
The length of the string is obtained with internal::StrLen(str).
\tparam CharType Character type of the string
\param str Constant string, lifetime assumed to be longer than the use of the string in e.g. a GenericValue
\return GenericStringRef string reference object
\relatesalso GenericStringRef
\see GenericValue::GenericValue(StringRefType), GenericValue::operator=(StringRefType), GenericValue::SetString(StringRefType), GenericValue::PushBack(StringRefType, Allocator&), GenericValue::AddMember
*/
template
inline GenericStringRef StringRef(const CharType* str) {
return GenericStringRef(str, internal::StrLen(str));
}
//! Mark a character pointer as constant string
/*! Mark a plain character pointer as a "string literal". This function
can be used to avoid copying a character string to be referenced as a
value in a JSON GenericValue object, if the string's lifetime is known
to be valid long enough.
This version has better performance with supplied length, and also
supports strings containing null characters.
\tparam CharType character type of the string
\param str Constant string, lifetime assumed to be longer than the use of the string in e.g. a GenericValue
\param length The length of source string. It is converted to SizeType with a cast.
\return GenericStringRef string reference object
\relatesalso GenericStringRef
*/
template
inline GenericStringRef StringRef(const CharType* str, size_t length) {
return GenericStringRef(str, SizeType(length));
}
#if RAPIDJSON_HAS_STDSTRING
//! Mark a string object as constant string
/*! Mark a string object (e.g. \c std::string) as a "string literal".
This function can be used to avoid copying a string to be referenced as a
value in a JSON GenericValue object, if the string's lifetime is known
to be valid long enough.
The reference is built from str.data() and str.size(); the size is
converted to SizeType with a cast.
\tparam CharType character type of the string
\param str Constant string, lifetime assumed to be longer than the use of the string in e.g. a GenericValue
\return GenericStringRef string reference object
\relatesalso GenericStringRef
\note Requires the definition of the preprocessor symbol \ref RAPIDJSON_HAS_STDSTRING.
*/
template
inline GenericStringRef StringRef(const std::basic_string& str) {
return GenericStringRef(str.data(), SizeType(str.size()));
}
#endif
///////////////////////////////////////////////////////////////////////////////
// GenericValue type traits
namespace internal {
// Compile-time trait answering whether a type is a GenericValue.
// Default case: not a GenericValue.
template
struct IsGenericValueImpl : FalseType {};
// select candidates according to nested encoding and allocator types
template struct IsGenericValueImpl::Type, typename Void::Type>
: IsBaseOf, T>::Type {};
// helper to match arbitrary GenericValue instantiations, including derived classes
template struct IsGenericValue : IsGenericValueImpl::Type {};
} // namespace internal
///////////////////////////////////////////////////////////////////////////////
// TypeHelper
namespace internal {
// TypeHelper maps a C++ type to the matching Is/Get/Set operations of a
// value type. Every specialization below offers the same four static
// functions:
//   Is(v)              - forwards to the matching v.IsXxx() test
//   Get(v)             - forwards to the matching v.GetXxx() accessor
//   Set(v, data)       - forwards to the matching v.SetXxx(data) mutator
//   Set(v, data, alloc) - same as above; the allocator argument is unused
// Primary template: empty, i.e. no operations for unsupported types.
template
struct TypeHelper {};
// Helper for bool values (IsBool / GetBool / SetBool).
template
struct TypeHelper {
static bool Is(const ValueType& v) { return v.IsBool(); }
static bool Get(const ValueType& v) { return v.GetBool(); }
static ValueType& Set(ValueType& v, bool data) { return v.SetBool(data); }
static ValueType& Set(ValueType& v, bool data, typename ValueType::AllocatorType&) { return v.SetBool(data); }
};
// Helper for int values (IsInt / GetInt / SetInt).
template
struct TypeHelper {
static bool Is(const ValueType& v) { return v.IsInt(); }
static int Get(const ValueType& v) { return v.GetInt(); }
static ValueType& Set(ValueType& v, int data) { return v.SetInt(data); }
static ValueType& Set(ValueType& v, int data, typename ValueType::AllocatorType&) { return v.SetInt(data); }
};
// Helper for unsigned values (IsUint / GetUint / SetUint).
template
struct TypeHelper {
static bool Is(const ValueType& v) { return v.IsUint(); }
static unsigned Get(const ValueType& v) { return v.GetUint(); }
static ValueType& Set(ValueType& v, unsigned data) { return v.SetUint(data); }
static ValueType& Set(ValueType& v, unsigned data, typename ValueType::AllocatorType&) { return v.SetUint(data); }
};
// Helper for int64_t values (IsInt64 / GetInt64 / SetInt64).
template
struct TypeHelper {
static bool Is(const ValueType& v) { return v.IsInt64(); }
static int64_t Get(const ValueType& v) { return v.GetInt64(); }
static ValueType& Set(ValueType& v, int64_t data) { return v.SetInt64(data); }
static ValueType& Set(ValueType& v, int64_t data, typename ValueType::AllocatorType&) { return v.SetInt64(data); }
};
// Helper for uint64_t values (IsUint64 / GetUint64 / SetUint64).
template
struct TypeHelper {
static bool Is(const ValueType& v) { return v.IsUint64(); }
static uint64_t Get(const ValueType& v) { return v.GetUint64(); }
static ValueType& Set(ValueType& v, uint64_t data) { return v.SetUint64(data); }
static ValueType& Set(ValueType& v, uint64_t data, typename ValueType::AllocatorType&) { return v.SetUint64(data); }
};
// Helper for double values (IsDouble / GetDouble / SetDouble).
template
struct TypeHelper {
static bool Is(const ValueType& v) { return v.IsDouble(); }
static double Get(const ValueType& v) { return v.GetDouble(); }
static ValueType& Set(ValueType& v, double data) { return v.SetDouble(data); }
static ValueType& Set(ValueType& v, double data, typename ValueType::AllocatorType&) { return v.SetDouble(data); }
};
// Helper for float values (IsFloat / GetFloat / SetFloat).
template
struct TypeHelper {
static bool Is(const ValueType& v) { return v.IsFloat(); }
static float Get(const ValueType& v) { return v.GetFloat(); }
static ValueType& Set(ValueType& v, float data) { return v.SetFloat(data); }
static ValueType& Set(ValueType& v, float data, typename ValueType::AllocatorType&) { return v.SetFloat(data); }
};
// Specialization for C-string pointers (const Ch*).
template
struct TypeHelper {
typedef const typename ValueType::Ch* StringType;
static bool Is(const ValueType& v) { return v.IsString(); }
static StringType Get(const ValueType& v) { return v.GetString(); }
// Without an allocator the pointer is wrapped in a StringRefType and passed to SetString.
// NOTE(review): this appears to store a reference rather than a copy, so the caller's buffer would have to outlive the value — confirm against SetString(StringRefType).
static ValueType& Set(ValueType& v, const StringType data) { return v.SetString(typename ValueType::StringRefType(data)); }
// With an allocator the pointer and allocator are both forwarded to SetString.
static ValueType& Set(ValueType& v, const StringType data, typename ValueType::AllocatorType& a) { return v.SetString(data, a); }
};
#if RAPIDJSON_HAS_STDSTRING
// Specialization for std::basic_string; only the allocator-taking Set is provided.
template
struct TypeHelper > {
typedef std::basic_string StringType;
static bool Is(const ValueType& v) { return v.IsString(); }
// Builds the result from pointer + length, so the explicit string length is used.
static StringType Get(const ValueType& v) { return StringType(v.GetString(), v.GetStringLength()); }
static ValueType& Set(ValueType& v, const StringType& data, typename ValueType::AllocatorType& a) { return v.SetString(data, a); }
};
#endif
// Specialization for the mutable array handle (ValueType::Array).
template
struct TypeHelper {
typedef typename ValueType::Array ArrayType;
static bool Is(const ValueType& v) { return v.IsArray(); }
// Takes a non-const value because it returns the mutable Array handle.
static ArrayType Get(ValueType& v) { return v.GetArray(); }
// Both Set overloads assign the handle to v; the allocator overload ignores its allocator.
static ValueType& Set(ValueType& v, ArrayType data) { return v = data; }
static ValueType& Set(ValueType& v, ArrayType data, typename ValueType::AllocatorType&) { return v = data; }
};
// Specialization for the read-only array handle (ValueType::ConstArray); no Set is provided.
template
struct TypeHelper {
typedef typename ValueType::ConstArray ArrayType;
static bool Is(const ValueType& v) { return v.IsArray(); }
static ArrayType Get(const ValueType& v) { return v.GetArray(); }
};
// Specialization for the mutable object handle (ValueType::Object).
// Is/Get forward to IsObject/GetObject; both Set overloads assign the handle to v.
template
struct TypeHelper {
typedef typename ValueType::Object ObjectType;
//! True when \c v holds an object.
static bool Is(const ValueType& v) { return v.IsObject(); }
//! Returns the mutable Object handle of \c v (hence the non-const parameter).
static ObjectType Get(ValueType& v) { return v.GetObject(); }
//! Assigns \c data to \c v and returns \c v.
static ValueType& Set(ValueType& v, ObjectType data) { return v = data; }
//! Allocator-taking overload; the allocator is unused.
// The assignment result must be returned: the function is declared to return
// ValueType&, and flowing off the end of a non-void function is undefined behavior.
static ValueType& Set(ValueType& v, ObjectType data, typename ValueType::AllocatorType&) { return v = data; }
};
// Specialization for the read-only object handle (ValueType::ConstObject); no Set is provided.
template
struct TypeHelper {
typedef typename ValueType::ConstObject ObjectType;
static bool Is(const ValueType& v) { return v.IsObject(); }
static ObjectType Get(const ValueType& v) { return v.GetObject(); }
};
} // namespace internal
// Forward declarations
// Declared ahead of GenericValue, whose member typedefs (Array, ConstArray,
// Object, ConstObject) refer to these class templates.
template class GenericArray;
template class GenericObject;
///////////////////////////////////////////////////////////////////////////////
// GenericValue
//! Represents a JSON value. Use Value for UTF8 encoding and default allocator.
/*!
A JSON value can be one of 7 types. This class is a variant type supporting
these types.
Use Value when UTF8 encoding and the default allocator are wanted.
\tparam Encoding Encoding of the value. (Even non-string values need to have the same encoding in a document)
\tparam Allocator Allocator type for allocating memory of object, array and string.
*/
template >
class GenericValue {
public:
//! Name-value pair in an object.
typedef GenericMember Member;
typedef Encoding EncodingType; //!< Encoding type from template parameter.
typedef Allocator AllocatorType; //!< Allocator type from template parameter.
typedef typename Encoding::Ch Ch; //!< Character type derived from Encoding.
typedef GenericStringRef StringRefType; //!< Reference to a constant string
typedef typename GenericMemberIterator::Iterator MemberIterator; //!< Member iterator for iterating in object.
typedef typename GenericMemberIterator::Iterator ConstMemberIterator; //!< Constant member iterator for iterating in object.
typedef GenericValue* ValueIterator; //!< Value iterator for iterating in array.
typedef const GenericValue* ConstValueIterator; //!< Constant value iterator for iterating in array.
typedef GenericValue ValueType; //!< Value type of itself.
typedef GenericArray Array; //!< Array handle type.
typedef GenericArray ConstArray; //!< Read-only array handle type.
typedef GenericObject Object; //!< Object handle type.
typedef GenericObject ConstObject; //!< Read-only object handle type.
//!@name Constructors and destructor.
//@{
//! Default constructor creates a null value.
/*! \c data_ is value-initialized and its flags are then set to \c kNullFlag.
*/
GenericValue() RAPIDJSON_NOEXCEPT : data_() { data_.f.flags = kNullFlag; }
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
//! Move constructor in C++11
/*! Copies the raw data of \c rhs into this value and then flags \c rhs as null.
\param rhs Source value; it is a null value afterwards.
*/
GenericValue(GenericValue&& rhs) RAPIDJSON_NOEXCEPT : data_(rhs.data_) {
rhs.data_.f.flags = kNullFlag; // give up contents
}
#endif
private:
// The members in this private section are only declared here, to make the
// corresponding operations unavailable to users of the class.
//! Copy constructor is not permitted.
GenericValue(const GenericValue& rhs);
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
//! Moving from a GenericDocument is not permitted.
template
GenericValue(GenericDocument&& rhs);
//! Move assignment from a GenericDocument is not permitted.
template
GenericValue& operator=(GenericDocument&& rhs);
#endif
public:
//! Constructs a value of the requested JSON type with default content.
/*! \param type JSON type to create; asserted to be no greater than \c kNumberType.
\note A number created this way is zero; a string created this way is empty.
*/
explicit GenericValue(Type type) RAPIDJSON_NOEXCEPT : data_() {
    // Initial flag word for each Type enumerator, indexed by the enumerator itself.
    static const uint16_t flagsForType[7] = {
        kNullFlag,
        kFalseFlag,
        kTrueFlag,
        kObjectFlag,
        kArrayFlag,
        kShortStringFlag,
        kNumberAnyFlag
    };
    RAPIDJSON_ASSERT(type <= kNumberType);
    data_.f.flags = flagsForType[type];
    if (type == kStringType) {
        // The empty string is stored in the short-string representation.
        data_.ss.SetLength(0);
    }
}
//! Explicit copy constructor (with allocator)
/*! Creates a copy of a Value by using the given Allocator
\tparam SourceAllocator allocator of \c rhs
\param rhs Value to copy from (read-only)
\param allocator Allocator for allocating copied elements and buffers. Commonly use GenericDocument::GetAllocator().
\note Only declared at this point; the definition is not part of this section of the file.
\see CopyFrom()
*/
template< typename SourceAllocator >
GenericValue(const GenericValue& rhs, Allocator & allocator);
//! Constructor for boolean value.
/*! \param b Boolean value
\note This constructor is limited to \em real boolean values and rejects
implicitly converted types like arbitrary pointers. Use an explicit cast
to \c bool, if you want to construct a boolean JSON value in such cases.
*/
#ifndef RAPIDJSON_DOXYGEN_RUNNING // hide SFINAE from Doxygen
template
explicit GenericValue(T b, RAPIDJSON_ENABLEIF((internal::IsSame))) RAPIDJSON_NOEXCEPT // See #472
#else
explicit GenericValue(bool b) RAPIDJSON_NOEXCEPT
#endif
: data_() {
// safe-guard against failing SFINAE
RAPIDJSON_STATIC_ASSERT((internal::IsSame::Value));
// The value is encoded entirely in the flags: kTrueFlag or kFalseFlag.
data_.f.flags = b ? kTrueFlag : kFalseFlag;
}
//! Constructs a number value from an \c int.
/*! \param i Value to store; it is written to the 64-bit signed slot.
*/
explicit GenericValue(int i) RAPIDJSON_NOEXCEPT : data_() {
    data_.n.i64 = i;
    if (i < 0) {
        // Negative values carry only the int number flag.
        data_.f.flags = kNumberIntFlag;
    }
    else {
        // Non-negative values are additionally flagged as uint and uint64.
        data_.f.flags = kNumberIntFlag | kUintFlag | kUint64Flag;
    }
}
//! Constructs a number value from an \c unsigned.
/*! \param u Value to store; it is written to the 64-bit unsigned slot.
*/
explicit GenericValue(unsigned u) RAPIDJSON_NOEXCEPT : data_() {
    data_.n.u64 = u;
    if (u & 0x80000000) {
        // Bit 31 is set: only the uint number flag applies.
        data_.f.flags = kNumberUintFlag;
    }
    else {
        // Bit 31 is clear: the value is additionally flagged as int and int64.
        data_.f.flags = kNumberUintFlag | kIntFlag | kInt64Flag;
    }
}
//! Constructor for int64_t value.
/*! Stores \c i64 and adds the narrower-type flags that the mask tests below allow.
\note presumably RAPIDJSON_UINT64_C2(high, low) builds a 64-bit constant from two 32-bit halves — confirm against its definition.
*/
explicit GenericValue(int64_t i64) RAPIDJSON_NOEXCEPT : data_() {
data_.n.i64 = i64;
data_.f.flags = kNumberInt64Flag;
if (i64 >= 0) {
// Non-negative values are also flagged as uint64.
data_.f.flags |= kNumberUint64Flag;
// No bit set under the 0xFFFFFFFF'00000000 mask: flag as uint.
if (!(static_cast(i64) & RAPIDJSON_UINT64_C2(0xFFFFFFFF, 0x00000000)))
data_.f.flags |= kUintFlag;
// No bit set under the 0xFFFFFFFF'80000000 mask: flag as int.
if (!(static_cast(i64) & RAPIDJSON_UINT64_C2(0xFFFFFFFF, 0x80000000)))
data_.f.flags |= kIntFlag;
}
// Negative values are flagged as int when not below the 0xFFFFFFFF'80000000 bound.
else if (i64 >= static_cast(RAPIDJSON_UINT64_C2(0xFFFFFFFF, 0x80000000)))
data_.f.flags |= kIntFlag;
}
//! Constructs a number value from a \c uint64_t.
/*! \param u64 Value to store. Flags for narrower number types are added
    whenever the corresponding high bits of \c u64 are all clear.
*/
explicit GenericValue(uint64_t u64) RAPIDJSON_NOEXCEPT : data_() {
    data_.n.u64 = u64;
    data_.f.flags = kNumberUint64Flag;
    // Top bit clear: also flagged as int64.
    if ((u64 & RAPIDJSON_UINT64_C2(0x80000000, 0x00000000)) == 0) {
        data_.f.flags |= kInt64Flag;
    }
    // No bit set under the 0xFFFFFFFF'00000000 mask: also flagged as uint.
    if ((u64 & RAPIDJSON_UINT64_C2(0xFFFFFFFF, 0x00000000)) == 0) {
        data_.f.flags |= kUintFlag;
    }
    // No bit set under the 0xFFFFFFFF'80000000 mask: also flagged as int.
    if ((u64 & RAPIDJSON_UINT64_C2(0xFFFFFFFF, 0x80000000)) == 0) {
        data_.f.flags |= kIntFlag;
    }
}
//! Constructs a number value from a \c double.
/*! \param d Value to store; the flags are set to \c kNumberDoubleFlag.
*/
explicit GenericValue(double d) RAPIDJSON_NOEXCEPT : data_() {
    data_.n.d = d;
    data_.f.flags = kNumberDoubleFlag;
}
//! Constructor for constant string (i.e. do not make a copy of string)
/*! \param s Pointer to the characters; must stay valid for as long as this value refers to it.
\param length Number of characters in \c s.
*/
GenericValue(const Ch* s, SizeType length) RAPIDJSON_NOEXCEPT : data_() { SetStringRaw(StringRef(s, length)); }
//! Constructor for constant string (i.e. do not make a copy of string)
/*! \param s String reference to store without copying.
*/
explicit GenericValue(StringRefType s) RAPIDJSON_NOEXCEPT : data_() { SetStringRaw(s); }
//! Constructor for copy-string (i.e. do make a copy of string)
/*! \param s Pointer to the characters to copy.
\param length Number of characters in \c s.
\param allocator Allocator passed to SetStringRaw for the copy.
*/
GenericValue(const Ch* s, SizeType length, Allocator& allocator) : data_() { SetStringRaw(StringRef(s, length), allocator); }
//! Constructor for copy-string (i.e. do make a copy of string)
/*! \param s Pointer to the characters to copy; the length is determined by the single-argument StringRef().
\param allocator Allocator passed to SetStringRaw for the copy.
*/
GenericValue(const Ch*s, Allocator& allocator) : data_() { SetStringRaw(StringRef(s), allocator); }
#if RAPIDJSON_HAS_STDSTRING
//! Constructor for copy-string from a string object (i.e. do make a copy of string)
/*! \param s String object whose characters are copied.
\param allocator Allocator passed to SetStringRaw for the copy.
\note Requires the definition of the preprocessor symbol \ref RAPIDJSON_HAS_STDSTRING.
*/
GenericValue(const std::basic_string& s, Allocator& allocator) : data_() { SetStringRaw(StringRef(s), allocator); }
#endif
//! Constructor for Array.
/*!
\param a An array obtained by \c GetArray().
\note \c Array is always pass-by-value.
\note The source array is moved into this value and the source array becomes empty.
*/
GenericValue(Array a) RAPIDJSON_NOEXCEPT : data_(a.value_.data_) {
    // This value now holds the array's data; leave the source behind as a
    // freshly initialized Data that is flagged as an array.
    Data emptied = Data();
    emptied.f.flags = kArrayFlag;
    a.value_.data_ = emptied;
}
//! Constructor for Object.
/*!
\param o An object obtained by \c GetObject().
\note \c Object is always pass-by-value.
\note The source object is moved into this value and the source object becomes empty.
*/
GenericValue(Object o) RAPIDJSON_NOEXCEPT : data_(o.value_.data_) {
    // This value now holds the object's data; leave the source behind as a
    // freshly initialized Data that is flagged as an object.
    Data emptied = Data();
    emptied.f.flags = kObjectFlag;
    o.value_.data_ = emptied;
}
//! Destructor.
/*! Need to destruct elements of array, members of object, or copy-string.
Nothing is done at all when \c Allocator::kNeedFree is false.
*/
~GenericValue() {
if (Allocator::kNeedFree) { // Shortcut by Allocator's trait
switch(data_.f.flags) {
case kArrayFlag:
{
// Destroy every element in place, then release the element buffer.
GenericValue* e = GetElementsPointer();
for (GenericValue* v = e; v != e + data_.a.size; ++v)
v->~GenericValue();
Allocator::Free(e);
}
break;
case kObjectFlag:
// Destroy every member in place, then release the member buffer.
for (MemberIterator m = MemberBegin(); m != MemberEnd(); ++m)
m->~Member();
Allocator::Free(GetMembersPointer());
break;
case kCopyStringFlag:
// Only the copy-string case frees a string buffer here.
Allocator::Free(const_cast(GetStringPointer()));
break;
default:
break; // Do nothing for other types.
}
}
}
//@}
//!@name Assignment operators
//@{
//! Assignment with move semantics.
/*! \param rhs Source of the assignment. It will become a null value after assignment.
\return Reference to this value.
\note Self-assignment is rejected by an assertion.
*/
GenericValue& operator=(GenericValue& rhs) RAPIDJSON_NOEXCEPT {
RAPIDJSON_ASSERT(this != &rhs);
// Run the destructor on the current contents before taking over rhs;
// the order of these two statements must not change.
this->~GenericValue();
RawAssign(rhs);
return *this;
}
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
//! Move assignment in C++11
/*! \param rhs Source of the assignment.
\return Reference to this value.
*/
GenericValue& operator=(GenericValue&& rhs) RAPIDJSON_NOEXCEPT {
// Move() returns an lvalue reference, so this goes through operator=(GenericValue&).
return *this = rhs.Move();
}
#endif
//! Assigns a constant string reference without copying the characters.
/*! \param str Constant string reference to be assigned
\return Reference to this value.
\note A dedicated overload is required so that string references do not clash with the generic primitive type assignment overload.
\see GenericStringRef, operator=(T)
*/
GenericValue& operator=(StringRefType str) RAPIDJSON_NOEXCEPT {
    // Wrap the reference in a temporary value, then hand it over through
    // the GenericValue& assignment.
    GenericValue wrapped(str);
    *this = wrapped;
    return *this;
}
//! Assignment with primitive types.
/*! \tparam T Either \ref Type, \c int, \c unsigned, \c int64_t, \c uint64_t
\param value The value to be assigned.
\return Reference to this value.
\note The source type \c T explicitly disallows all pointer types,
especially (\c const) \ref Ch*. This helps avoiding implicitly
referencing character strings with insufficient lifetime, use
\ref SetString(const Ch*, Allocator&) (for copying) or
\ref StringRef() (to explicitly mark the pointer as constant) instead.
All other pointer types would implicitly convert to \c bool,
use \ref SetBool() instead.
*/
template
RAPIDJSON_DISABLEIF_RETURN((internal::IsPointer), (GenericValue&))
operator=(T value) {
// Build a temporary value from the primitive, then assign it through the GenericValue& overload.
GenericValue v(value);
return *this = v;
}
//! Deep-copy assignment from Value
/*! Assigns a \b copy of the Value to the current Value object
\tparam SourceAllocator Allocator type of \c rhs
\param rhs Value to copy from (read-only)
\param allocator Allocator to use for copying
\return Reference to this value.
\note Copying a value from itself is rejected by an assertion.
*/
template
GenericValue& CopyFrom(const GenericValue& rhs, Allocator& allocator) {
RAPIDJSON_ASSERT(static_cast(this) != static_cast(&rhs));
// Destroy the current contents, then construct the copy in the same storage.
this->~GenericValue();
new (this) GenericValue(rhs, allocator);
return *this;
}
//! Swaps the contents of this value with those of another value.
/*!
\param other Value to exchange contents with.
\return Reference to this value.
\note Constant complexity.
*/
GenericValue& Swap(GenericValue& other) RAPIDJSON_NOEXCEPT {
    // Three raw assignments through a scratch value: this -> scratch,
    // other -> this, scratch -> other.
    GenericValue scratch;
    scratch.RawAssign(*this);
    RawAssign(other);
    other.RawAssign(scratch);
    return *this;
}
//! free-standing swap function helper
/*!
Helper function to enable support for common swap implementation pattern based on \c std::swap:
\code
void swap(MyClass& a, MyClass& b) {
using std::swap;
swap(a.value, b.value);
// ...
}
\endcode
\param a First value.
\param b Second value; its contents are exchanged with those of \c a via a.Swap(b).
\see Swap()
*/
friend inline void swap(GenericValue& a, GenericValue& b) RAPIDJSON_NOEXCEPT { a.Swap(b); }
//! Prepare Value for move semantics
/*! Performs no work itself; it only yields a non-const lvalue reference to this value.
\return *this */
GenericValue& Move() RAPIDJSON_NOEXCEPT { return *this; }
//@}
//!@name Equal-to and not-equal-to operators
//@{
//! Equal-to operator
/*!
\param rhs Value to compare against.
\return \c true when both values have the same type and equal content.
\note If an object contains duplicated named member, comparing equality with any object is always \c false.
\note Linear time complexity (number of all values in the subtree and total lengths of all strings).
*/
template
bool operator==(const GenericValue& rhs) const {
typedef GenericValue RhsType;
// Values of different JSON types are never equal.
if (GetType() != rhs.GetType())
return false;
switch (GetType()) {
case kObjectType: // Warning: O(n^2) inner-loop
if (data_.o.size != rhs.data_.o.size)
return false;
// Every member on the left must be found by name on the right with an equal value.
for (ConstMemberIterator lhsMemberItr = MemberBegin(); lhsMemberItr != MemberEnd(); ++lhsMemberItr) {
typename RhsType::ConstMemberIterator rhsMemberItr = rhs.FindMember(lhsMemberItr->name);
if (rhsMemberItr == rhs.MemberEnd() || lhsMemberItr->value != rhsMemberItr->value)
return false;
}
return true;
case kArrayType:
if (data_.a.size != rhs.data_.a.size)
return false;
// Arrays are compared element by element, position by position.
for (SizeType i = 0; i < data_.a.size; i++)
if ((*this)[i] != rhs[i])
return false;
return true;
case kStringType:
return StringEqual(rhs);
case kNumberType:
// If either side is a double, both sides are compared as doubles.
if (IsDouble() || rhs.IsDouble()) {
double a = GetDouble(); // May convert from integer to double.
double b = rhs.GetDouble(); // Ditto
return a >= b && a <= b; // Prevent -Wfloat-equal
}
else
// Otherwise the stored 64-bit unsigned representations are compared directly.
return data_.n.u64 == rhs.data_.n.u64;
default:
// All remaining types are equal once the GetType() check above has passed.
return true;
}
}
//! Equal-to operator with const C-string pointer
/*! \param rhs C-string to compare against; it is wrapped by StringRef() in a temporary value, which is then compared with the GenericValue equality operator.
*/
bool operator==(const Ch* rhs) const { return *this == GenericValue(StringRef(rhs)); }
#if RAPIDJSON_HAS_STDSTRING
//! Equal-to operator with string object
/*! \param rhs String object to compare against; it is wrapped by StringRef() in a temporary value, which is then compared with the GenericValue equality operator.
\note Requires the definition of the preprocessor symbol \ref RAPIDJSON_HAS_STDSTRING.
*/
bool operator==(const std::basic_string& rhs) const { return *this == GenericValue(StringRef(rhs)); }
#endif
//! Equal-to operator with primitive types
/*! \tparam T Either \ref Type, \c int, \c unsigned, \c int64_t, \c uint64_t, \c double, \c true, \c false
*/
template RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr