resultFiles;
/**
* 微调作业的当前状态
*/
private String status;
/**
* 此微调作业处理的计费令牌总数
*/
@JsonProperty("trained_tokens")
private Integer trainedTokens;
/**
* 用于训练的文件 ID
*/
@JsonProperty("training_file")
private String trainingFile;
/**
* 用于验证的文件 ID
*/
@JsonProperty("validation_file")
private String validationFile;
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/endPoint/images/ImageObject.java
================================================
package com.ai.openai.endPoint.images;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * One image entry returned by the OpenAI image endpoints.
 *
 * <p>Carries a URL, a base64-encoded payload ({@code b64_json}) and the revised prompt
 * ({@code revised_prompt}). Which of these are filled presumably depends on the requested
 * response format — confirm against the API response.
 */
@AllArgsConstructor
@NoArgsConstructor
@Data
public class ImageObject {
    /** Image URL. */
    private String url;

    /** Base64-encoded image content, mapped from {@code b64_json}. */
    @JsonProperty("b64_json")
    private String b64Json;

    /** Revised prompt text, mapped from {@code revised_prompt}. */
    @JsonProperty("revised_prompt")
    private String revisedPrompt;
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/endPoint/images/req/CreateImageRequest.java
================================================
package com.ai.openai.endPoint.images.req;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.*;
import java.io.Serializable;
/**
 * Request body for the OpenAI image generation endpoint.
 *
 * <p>Null fields are omitted when serialized and unknown properties are ignored when read.
 * {@link #baseBuild(String)} creates a request that only sets the prompt and keeps every default.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public class CreateImageRequest implements Serializable {
    /**
     * Text description of the desired image(s). The maximum length is 1000 characters for
     * dall-e-2 and 4000 characters for dall-e-3.
     */
    private String prompt;

    /**
     * Model used for image generation; defaults to {@code dall-e-2}.
     *
     * @see Model
     */
    @Builder.Default
    private String model = Model.DALL_E_2.getName();

    /**
     * Number of images to generate. Must be between 1 and 10; dall-e-3 only accepts 1.
     * Defaults to 1.
     */
    @Builder.Default
    private Integer n = 1;

    /**
     * Quality of the generated image; the higher quality setting creates images with finer
     * details and greater consistency.
     *
     * @see Quality
     */
    private String quality;

    /**
     * Format in which the generated images are returned: {@code url} or {@code b64_json}.
     *
     * @see Format
     */
    @JsonProperty("response_format")
    private String responseFormat;

    /**
     * Image size, 1024x1024 by default.
     * dall-e-2 supports 256x256, 512x512 or 1024x1024;
     * dall-e-3 supports 1024x1024, 1792x1024 or 1024x1792.
     */
    private String size;

    /**
     * Style of the generated images, dall-e-3 only: {@code vivid} or {@code natural}.
     * Defaults to {@code vivid}.
     *
     * @see Style
     */
    private String style;

    /** Unique identifier representing the end user. */
    private String user;

    /**
     * Creates a request that only carries a prompt; all other fields keep their defaults.
     *
     * @param prompt the text prompt
     * @return the request
     */
    public static CreateImageRequest baseBuild(String prompt) {
        return builder().prompt(prompt).build();
    }

    /** Image generation models. */
    public enum Model {
        DALL_E_2("dall-e-2"),
        DALL_E_3("dall-e-3");

        private final String name;

        Model(String name) {
            this.name = name;
        }

        public String getName() {
            return name;
        }
    }

    /** Image quality options. */
    public enum Quality {
        STANDARD("standard"),
        HD("hd");

        private final String quality;

        Quality(String quality) {
            this.quality = quality;
        }

        public String getQuality() {
            return quality;
        }
    }

    /** Image style options. */
    public enum Style {
        VIVID("vivid"),
        NATURAL("natural");

        private final String style;

        Style(String style) {
            this.style = style;
        }

        public String getStyle() {
            return style;
        }
    }

    /** Response formats for the generated images. */
    public enum Format {
        URL("url"),
        B64JSON("b64_json");

        private final String format;

        Format(String format) {
            this.format = format;
        }

        public String getFormat() {
            return format;
        }
    }
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/endPoint/images/req/ImageEditRequest.java
================================================
package com.ai.openai.endPoint.images.req;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.*;
import java.io.Serializable;
/**
 * Request parameters for the OpenAI image edit endpoint.
 *
 * <p>{@code prompt} is mandatory (null-checked by Lombok's {@code @NonNull}). The other fields
 * default to a single 1024x1024 image returned as a URL.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public class ImageEditRequest {
    /** Required: text description of the edit, at most 1000 characters. */
    @NonNull
    private String prompt;

    /** Number of images to generate for the prompt; defaults to 1. */
    @Builder.Default
    private Integer n = 1;

    /**
     * Output size: 256x256, 512x512 or 1024x1024 (the default).
     * NOTE(review): {@link SizeEnum} also lists 1792-based sizes that are not among the values
     * documented here — confirm whether the edit endpoint accepts them.
     */
    @Builder.Default
    private String size = SizeEnum.size_1024.getName();

    /** Response format; defaults to {@code url}. */
    @JsonProperty("response_format")
    @Builder.Default
    private String responseFormat = ResponseFormat.URL.getName();

    /** Unique identifier representing the end user. */
    private String user;

    /**
     * Creates a request that only carries a prompt; all other fields keep their defaults.
     *
     * @param prompt the text prompt
     * @return the request
     */
    public static ImageEditRequest baseBuild(String prompt) {
        return builder().prompt(prompt).build();
    }

    /** Supported image sizes. */
    public enum SizeEnum implements Serializable {
        size_1024_1792("1024x1792"),
        size_1792_1024("1792x1024"),
        size_1024("1024x1024"),
        size_512("512x512"),
        size_256("256x256");

        private final String name;

        SizeEnum(String name) {
            this.name = name;
        }

        public String getName() {
            return name;
        }
    }

    /** Supported response formats. */
    public enum ResponseFormat implements Serializable {
        URL("url"),
        B64_JSON("b64_json");

        private final String name;

        ResponseFormat(String name) {
            this.name = name;
        }

        public String getName() {
            return name;
        }
    }
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/endPoint/images/req/ImageVariationRequest.java
================================================
package com.ai.openai.endPoint.images.req;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import java.io.Serializable;
/**
 * Request parameters for the OpenAI image variation endpoint.
 *
 * <p>Defaults to a single 1024x1024 image returned as a URL. The source image itself is not
 * part of this object; presumably the caller passes it separately — confirm at the call site.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ImageVariationRequest {
    /** Number of images to generate; defaults to 1. */
    @Builder.Default
    private Integer n = 1;

    /**
     * Output size: 256x256, 512x512 or 1024x1024 (the default).
     * NOTE(review): {@link SizeEnum} also lists 1792-based sizes that are not among the values
     * documented here — confirm whether the variation endpoint accepts them.
     */
    @Builder.Default
    private String size = SizeEnum.size_1024.getName();

    /** Response format; defaults to {@code url}. */
    @JsonProperty("response_format")
    @Builder.Default
    private String responseFormat = ResponseFormat.URL.getName();

    /** Unique identifier representing the end user. */
    private String user;

    /** Supported image sizes. */
    @Getter
    @AllArgsConstructor
    public enum SizeEnum implements Serializable {
        size_1024_1792("1024x1792"),
        size_1792_1024("1792x1024"),
        size_1024("1024x1024"),
        size_512("512x512"),
        size_256("256x256"),
        ;
        private final String name;
    }

    /** Supported response formats. */
    @AllArgsConstructor
    @Getter
    public enum ResponseFormat implements Serializable {
        URL("url"),
        B64_JSON("b64_json"),
        ;
        private final String name;
    }
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/endPoint/images/resp/CreateImageResponse.java
================================================
package com.ai.openai.endPoint.images.resp;
import com.ai.openai.endPoint.images.ImageObject;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
/**
 * Response of the OpenAI image endpoints.
 *
 * <p>{@code data} is typed as {@link ImageObject} so that Jackson binds each entry to that class.
 * With a raw {@code List}, each entry would be deserialized as a generic map.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class CreateImageResponse {
    /** Creation time as reported by the API (presumably a Unix timestamp in seconds). */
    private Long created;

    /** Generated images. */
    private List<ImageObject> data;
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/endPoint/models/ModelObject.java
================================================
package com.ai.openai.endPoint.models;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * A model entry as described by the OpenAI models endpoint.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class ModelObject {
    /** Model identifier, usable when referencing the model in API endpoints. */
    private String id;

    /** Object type, always {@code "model"}. */
    private String object;

    /** Unix timestamp (in seconds) at which the model was created. */
    private long created;

    /** Organization that owns the model, mapped from {@code owned_by}. */
    @JsonProperty("owned_by")
    private String ownedBy;
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/endPoint/models/resp/DeleteFineTuneModelResponse.java
================================================
package com.ai.openai.endPoint.models.resp;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
/**
 * Response returned after requesting deletion of a fine-tuned model.
 *
 * <p>Null fields are omitted when serialized; unknown properties are ignored when read.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public class DeleteFineTuneModelResponse implements Serializable {
    /** Identifier of the model the deletion refers to. */
    private String id;

    /** Object type reported by the API. */
    private String object;

    /** Deletion flag reported by the API (presumably {@code true} on success). */
    private Boolean deleted;
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/endPoint/moderations/Categories.java
================================================
package com.ai.openai.endPoint.moderations;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
/**
 * Per-category flags of a moderation result: whether the input was flagged for each category.
 * JSON names containing '/' or '-' are mapped via {@link JsonProperty}.
 */
@NoArgsConstructor
@AllArgsConstructor
@Data
public class Categories implements Serializable {
    /**
     * Content that expresses, incites, or promotes hate based on race, gender, ethnicity,
     * religion, nationality, sexual orientation, disability status, or caste.
     */
    private boolean hate;

    /** Hateful content that also includes violence or serious harm towards the targeted group. */
    @JsonProperty("hate/threatening")
    private boolean hateThreatening;

    /** Content that expresses, incites, or promotes harassing language towards any target. */
    @JsonProperty("harassment")
    private boolean harassment;

    /** Harassment content that also includes violence or serious harm towards any target. */
    @JsonProperty("harassment/threatening")
    private boolean harassmentThreatening;

    /**
     * Content that promotes, encourages, or depicts acts of self-harm, such as suicide,
     * cutting, and eating disorders.
     */
    @JsonProperty("self-harm")
    private boolean selfHarm;

    /**
     * Content where the speaker expresses that they are engaging or intend to engage in acts of
     * self-harm, such as suicide, cutting, and eating disorders.
     */
    @JsonProperty("self-harm/intent")
    private boolean selfHarmIntent;

    /**
     * Content that encourages performing acts of self-harm (such as suicide, cutting, and eating
     * disorders), or that gives instructions or advice on how to commit such acts.
     */
    @JsonProperty("self-harm/instructions")
    private boolean selfHarmInstructions;

    /**
     * Content meant to arouse sexual excitement, such as the description of sexual activity, or
     * that promotes sexual services (excluding sex education and wellness).
     */
    private boolean sexual;

    /** Sexual content that includes an individual who is under 18 years old. */
    @JsonProperty("sexual/minors")
    private boolean sexualMinors;

    /** Content that promotes or glorifies violence or celebrates the suffering or humiliation of others. */
    private boolean violence;

    /** Violent content that depicts death, violence, or serious physical injury in extreme graphic detail. */
    @JsonProperty("violence/graphic")
    private boolean violenceGraphic;
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/endPoint/moderations/CategoryScores.java
================================================
package com.ai.openai.endPoint.moderations;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
import java.math.BigDecimal;
/**
 * Per-category scores of a moderation result, as predicted by the moderation model.
 * JSON names containing '/' or '-' are mapped via {@link JsonProperty}.
 */
@NoArgsConstructor
@AllArgsConstructor
@Data
public class CategoryScores implements Serializable {
    /** Score for hateful content. */
    private BigDecimal hate;

    /** Score for hateful content that also includes threats of violence or serious harm. */
    @JsonProperty("hate/threatening")
    private BigDecimal hateThreatening;

    /** Score for content that expresses, incites, or promotes harassing language towards any target. */
    @JsonProperty("harassment")
    private BigDecimal harassment;

    /** Score for harassment content that also includes violence or serious harm towards any target. */
    @JsonProperty("harassment/threatening")
    private BigDecimal harassmentThreatening;

    /** Score for content that promotes, encourages, or depicts acts of self-harm. */
    @JsonProperty("self-harm")
    private BigDecimal selfHarm;

    /**
     * Score for content where the speaker expresses that they are engaging or intend to engage in
     * acts of self-harm, such as suicide, cutting, and eating disorders.
     */
    @JsonProperty("self-harm/intent")
    private BigDecimal selfHarmIntent;

    /**
     * Score for content that encourages acts of self-harm (such as suicide, cutting, and eating
     * disorders), or that gives instructions or advice on how to commit such acts.
     */
    @JsonProperty("self-harm/instructions")
    private BigDecimal selfHarmInstructions;

    /** Score for sexual content. */
    private BigDecimal sexual;

    /** Score for sexual content involving individuals under 18. */
    @JsonProperty("sexual/minors")
    private BigDecimal sexualMinors;

    /** Score for violent content. */
    private BigDecimal violence;

    /** Score for graphically violent content. */
    @JsonProperty("violence/graphic")
    private BigDecimal violenceGraphic;
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/endPoint/moderations/Result.java
================================================
package com.ai.openai.endPoint.moderations;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
/**
 * A single moderation result: the overall flag plus per-category flags and scores.
 */
@NoArgsConstructor
@AllArgsConstructor
@Data
public class Result implements Serializable {
    /** Whether the content violates OpenAI's usage policies. */
    private boolean flagged;

    /** Categories and whether each one was flagged. */
    private Categories categories;

    /** Categories and the scores predicted for them by the model, mapped from {@code category_scores}. */
    @JsonProperty("category_scores")
    private CategoryScores categoryScores;

    /**
     * The original input text. Excluded from JSON (de)serialization via {@link JsonIgnore};
     * presumably filled in locally by the caller — confirm at the call site.
     */
    @JsonIgnore
    private String content;
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/endPoint/moderations/req/ModerationRequest.java
================================================
package com.ai.openai.endPoint.moderations.req;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.*;
import java.util.ArrayList;
import java.util.List;
/**
 * Request body for the OpenAI moderation endpoint.
 *
 * <p>{@code input} is mandatory (null-checked by Lombok's {@code @NonNull}); {@code model}
 * defaults to {@code text-moderation-latest}.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ModerationRequest {
    /** Texts to classify. */
    @NonNull
    private List<String> input;

    /** Moderation model; defaults to {@code text-moderation-latest}. */
    @Builder.Default
    private String model = Model.TEXT_MODERATION_LATEST.getName();

    /**
     * Builds a request for a single text.
     *
     * @param input the text to classify
     * @return request parameters
     */
    public static ModerationRequest baseBuild(String input) {
        // Mutable list so callers can still append inputs via getInput().add(...)
        List<String> list = new ArrayList<>();
        list.add(input);
        return baseBuild(list);
    }

    /**
     * Builds a request for several texts.
     *
     * @param inputList the texts to classify; must not be null
     * @return request parameters
     */
    public static ModerationRequest baseBuild(List<String> inputList) {
        return ModerationRequest.builder().input(inputList).build();
    }

    /** Available moderation models. */
    @Getter
    @AllArgsConstructor
    public enum Model {
        TEXT_MODERATION_STABLE("text-moderation-stable"),
        TEXT_MODERATION_LATEST("text-moderation-latest"),
        ;
        private final String name;
    }
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/endPoint/moderations/resp/ModerationResponse.java
================================================
package com.ai.openai.endPoint.moderations.resp;
import com.ai.openai.endPoint.moderations.Result;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
/**
 * Response of the OpenAI moderation endpoint.
 *
 * <p>{@code results} is typed as {@link Result} so that Jackson binds each entry to that class.
 * With a raw {@code List}, each entry would be deserialized as a generic map.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ModerationResponse {
    /** Unique identifier of the moderation request. */
    private String id;

    /** Model used to produce the moderation results. */
    private String model;

    /** One moderation result per input. */
    private List<Result> results;
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/interceptor/HeaderInterceptor.java
================================================
package com.ai.openai.interceptor;
import cn.hutool.http.ContentType;
import cn.hutool.http.Header;
import com.ai.core.strategy.KeyStrategy;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import org.jetbrains.annotations.NotNull;
import java.io.IOException;
import java.util.List;
import static com.ai.core.exception.Constants.*;
/**
* @Description: OpenAI Key 拦截器
**/
public class HeaderInterceptor implements Interceptor {
/**
* 系统设置的 openAI apiKey
*/
private final List apiKeyBySystem;
/**
* 系统设置的 api 请求地址
*/
private final String apiHostBySystem;
/**
* key 获取策略
*/
private KeyStrategy, String> keyStrategy;
public HeaderInterceptor(List apiKeyBySystem, String apiHostBySystem, KeyStrategy keyStrategy) {
this.apiKeyBySystem = apiKeyBySystem;
this.apiHostBySystem = apiHostBySystem;
this.keyStrategy = keyStrategy;
}
@NotNull
@Override
public Response intercept(Chain chain) throws IOException {
// 1. 获取原始 Request
Request originalReq = chain.request();
// 2. 读取 apiKey;优先使用用户传递的 apiKey
String apiKeyByUser = originalReq.header(API_KEY);
String apiHostByUser = originalReq.header(API_HOST);
String apiUrlByUser = originalReq.header(URL);
String apiKey = apiKeyByUser == NULL ? keyStrategy.apply(apiKeyBySystem) : apiKeyByUser;
// 3. 读取 apiUrl 和 apiHost,apiUrl 优先级大于 apiHost
String apiUrl = apiUrlByUser == NULL ? apiHostByUser == NULL ? String.valueOf(originalReq.url()) : apiHostByUser + originalReq.url().url().getPath() : apiUrlByUser;
// 4. 构建 Request
Request request = originalReq.newBuilder()
.url(apiUrl)
.header(Header.AUTHORIZATION.getValue(), "Bearer " + apiKey)
.header(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue())
.method(originalReq.method(), originalReq.body())
.build();
// 5. 返回执行结果
return chain.proceed(request);
}
}
================================================
FILE: ai-openai/src/main/java/com/ai/openai/interceptor/ResponseInterceptor.java
================================================
package com.ai.openai.interceptor;
import cn.hutool.json.JSONUtil;
import com.ai.core.exception.BaseException;
import com.ai.core.exception.Constants;
import com.ai.openai.common.CommonListResponse;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import java.io.IOException;
import java.util.Objects;
@Slf4j
public class ResponseInterceptor implements Interceptor {
@Override
public Response intercept(Chain chain) throws IOException {
// 1. 获取 req 和 resp
Request original = chain.request();
Response response = chain.proceed(original);
// 2. 判断返回状态
if (!response.isSuccessful() && response.body() != null) {
// 2.1 获取返回的错误信息
String errorMsg = response.body().string();
CommonListResponse openAiResponse = JSONUtil.toBean(errorMsg, CommonListResponse.class);
if (Constants.ERROR_MSG_MAP.containsKey(response.code())) {
log.error(openAiResponse.getError().getMessage());
throw new BaseException(openAiResponse.getError().getMessage());
}
log.error("--------> 请求异常:{}", errorMsg);
if (Objects.nonNull(openAiResponse.getError())) {
log.error(openAiResponse.getError().getMessage());
throw new BaseException(openAiResponse.getError().getMessage());
}
throw new BaseException(Constants.ErrorMsg.RETRY_ERROR);
}
return response;
}
}
================================================
FILE: ai-openai/src/test/java/com/ai/openai/AudioApiTest.java
================================================
package com.ai.openai;
import com.ai.core.strategy.impl.FirstKeyStrategy;
import com.ai.openai.achieve.Configuration;
import com.ai.openai.achieve.defaults.DefaultOpenAiSessionFactory;
import com.ai.openai.achieve.standard.session.AggregationSession;
import com.ai.openai.endPoint.audio.req.SttCompletionRequest;
import com.ai.openai.endPoint.audio.req.TtsCompletionRequest;
import com.ai.openai.endPoint.audio.resp.SttCompletionResponse;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import org.junit.Before;
import org.junit.Test;
import retrofit2.Call;
import retrofit2.Callback;
import retrofit2.Response;
import java.io.*;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import static com.ai.core.exception.Constants.NULL;
/**
* @Description: 测试语音相关接口功能
**/
@Slf4j
public class AudioApiTest {
private AggregationSession aggregationSession;
@Before
public void test_OpenAiSessionFactory() {
// 1. 创建配置类
Configuration configuration = new Configuration();
// 2. 设置请求地址,若有代理商或者代理服务器,可填写为代理服务器的请求路径
configuration.setApiHost("https://api.openai.com");
// 3. 设置鉴权所需的API Key,可设置多个。
configuration.setKeyList(Arrays.asList("**************************"));
// 4. 设置请求时 key 的使用策略,默认实现了:随机获取 和 固定第一个Key 两种方式。
configuration.setKeyStrategy(new FirstKeyStrategy());
// configuration.setKeyStrategy(new RandomKeyStrategy());
// 5. 设置代理,若不需要可不设置
configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)));
// 6. 创建 session 工厂,制造不同场景的 session
DefaultOpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration);
this.aggregationSession = factory.openAggregationSession();
}
/**
* 测试文字转语音
*/
@Test
public void test_tts() throws InterruptedException {
// 定义请求参数
TtsCompletionRequest ttsCompletionRequest = TtsCompletionRequest.builder()
.model(TtsCompletionRequest.Model.tts_1.getModuleName())// 设置使用的模型
.input("你好,我是chatGPT")
.voice(TtsCompletionRequest.Voice.alloy.getVoiceName())// 设置声音的样式
.build();
// 回传文件存放的路径
File file = new File("doc/test/test_tts.mp3");
// 添加回调函数,发送请求
aggregationSession.getAudioSession().ttsCompletions(NULL, NULL, NULL, ttsCompletionRequest, new Callback() {
@Override
public void onResponse(Call call, Response response) {
try (InputStream inputStream = response.body().byteStream();
OutputStream os = new BufferedOutputStream(new FileOutputStream(file))) {
// 创建文件
if (!file.exists()) {
if (!file.getParentFile().exists()) file.getParentFile().mkdir();
file.createNewFile();
}
byte data[] = new byte[10240];
int len;
while ((len = inputStream.read(data, 0, 10240)) != -1) {
os.write(data, 0, len);
}
} catch (IOException e) {
e.printStackTrace();
}
}
@Override
public void onFailure(Call call, Throwable t) {
t.printStackTrace();
}
}
);
// 阻塞等待
new CountDownLatch(1).await();
}
/**
* 测试语音转文字
*/
@Test
public void test_stt() {
// 音频文件存放路径
File file = new File("doc/test/test_tts.mp3");
SttCompletionRequest sttCompletionRequest = SttCompletionRequest.builder().file(file).build();
SttCompletionResponse sttCompletionResponse = this.aggregationSession.getAudioSession().sttCompletions(NULL, NULL, NULL, sttCompletionRequest);
log.info("测试结果:{}", sttCompletionResponse);
}
/**
* 测试音频文件转文字后翻译为英文
*/
@Test
public void test_translation() {
// 音频文件存放路径
File file = new File("doc/test/test_tts.mp3");
SttCompletionRequest sttCompletionRequest = SttCompletionRequest.builder().file(file).build();
SttCompletionResponse sttCompletionResponse = this.aggregationSession.getAudioSession().translationCompletions(NULL, NULL, NULL, sttCompletionRequest);
log.info("测试结果:{}", sttCompletionResponse);
}
}
================================================
FILE: ai-openai/src/test/java/com/ai/openai/ChatApiTest.java
================================================
package com.ai.openai;
import cn.hutool.json.JSONObject;
import com.ai.core.exception.Constants;
import com.ai.core.strategy.impl.FirstKeyStrategy;
import com.ai.openai.achieve.Configuration;
import com.ai.openai.achieve.defaults.DefaultOpenAiSessionFactory;
import com.ai.openai.achieve.standard.session.AggregationSession;
import com.ai.openai.endPoint.chat.Parameters;
import com.ai.openai.endPoint.chat.msg.Content;
import com.ai.openai.endPoint.chat.msg.DefaultMessage;
import com.ai.openai.endPoint.chat.msg.ImgMessage;
import com.ai.openai.endPoint.chat.req.*;
import com.ai.openai.endPoint.chat.resp.ChatCompletionResponse;
import com.ai.openai.endPoint.chat.resp.QaCompletionResponse;
import com.ai.openai.endPoint.chat.tools.Tool;
import com.ai.openai.endPoint.chat.tools.ToolFunction;
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.junit.Before;
import org.junit.Test;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.Arrays;
import java.util.Collections;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import static com.ai.core.exception.Constants.NULL;
/**
* @Description: 测试聊天接口相关接口功能
*/
@Slf4j
public class ChatApiTest {
private AggregationSession aggregationSession;
@Before
public void test_OpenAiSessionFactory() {
// 1. 创建配置类
Configuration configuration = new Configuration();
// 2. 设置请求地址,若有代理商或者代理服务器,可填写为代理服务器的请求路径
configuration.setApiHost("https://api.openai.com");
// 3. 设置鉴权所需的API Key,可设置多个。
configuration.setKeyList(Arrays.asList("填入你的API Key"));
// 4. 设置请求时 key 的使用策略,默认实现了:随机获取 和 固定第一个Key 两种方式。
configuration.setKeyStrategy(new FirstKeyStrategy());
// configuration.setKeyStrategy(new RandomKeyStrategy());
// 5. 设置代理,若不需要可不设置
configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)));
// 6. 创建 session 工厂,制造不同场景的 session
DefaultOpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration);
this.aggregationSession = factory.openAggregationSession();
}
/**
* 测试简单问答,相当于只有一轮问答。
*/
@Test
public void test_qa_completions() {
QaCompletionRequest qaCompletionRequest = QaCompletionRequest.baseBuild("你是谁?");
QaCompletionResponse qaCompletionResponse = aggregationSession.getChatSession().qaCompletions(NULL, NULL, NULL, qaCompletionRequest);
log.info("测试结果:{}", qaCompletionResponse);
}
/**
* 测试简单问答,流式返回结果。
*/
@Test
public void test_qa_completions_stream() throws InterruptedException, JsonProcessingException {
QaCompletionRequest qaCompletionRequest = QaCompletionRequest.builder()
.prompt("讲一个笑话")
.stream(true) // 设置流式返回
.build();
// 监听器监听返回的结果
aggregationSession.getChatSession().qaCompletions(NULL, NULL, NULL, qaCompletionRequest, new EventSourceListener() {
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
log.info("测试结果 id:{} type:{} data:{}", id, type, data);
}
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
log.error("失败 code:{} message:{}", response.code(), response.message());
}
});
// 阻塞等待
new CountDownLatch(1).await();
}
/**
* 测试多轮对话
*/
@Test
public void test_chat_completions() {
// 创建参数,上下文对话。
// 第一次的问题
DefaultChatCompletionRequest defaultChatCompletionRequest = DefaultChatCompletionRequest.baseBuild("1+1=");
// 第一次的回复
defaultChatCompletionRequest.addMessage(Constants.Role.ASSISTANT.getRoleName(), "2");
// 第二次的问题
defaultChatCompletionRequest.addMessage(Constants.Role.USER.getRoleName(), "2+2=");
// 询问第二次的问题的结果
ChatCompletionResponse chatCompletionResponse = aggregationSession.getChatSession().chatCompletions(NULL, NULL, NULL, defaultChatCompletionRequest);
// 解析结果
chatCompletionResponse.getChoices().forEach(e -> {
log.info("测试结果:{}", e.getMessage());
});
}
/**
* 测试函数对话,创建一个函数获取天气信息
* 下面的请求参数根据官方案例转换而来
*/
@Test
public void test_func_chat_completions() {
// 定义第一个属性,地址信息
JSONObject location = new JSONObject();
location.putOpt("type", "string");
location.putOpt("description", "The city and state, e.g. San Francisco, CA");
// 定义第二个属性,时间信息
JSONObject unit = new JSONObject();
unit.putOpt("type", "string");
unit.putOpt("enum", Arrays.asList("celsius", "fahrenheit"));
// 定义 properties,及将函数属性组合起来
JSONObject properties = new JSONObject();
properties.putOpt("location", location);
properties.putOpt("unit", unit);
// 定义 parameters
Parameters parameters = Parameters.builder().type("object").properties(properties).required(Arrays.asList("location")).build();
// 构造函数信息
ToolFunction toolFunction = ToolFunction.builder().name("get_current_weather").description("Get the current weather in a given location").parameters(parameters).build();
// 构造工具
Tool tool = Tool.builder().type(Tool.Type.FUNCTION.getName()).function(toolFunction).build();
// 构造请求参数
FuncChatCompletionRequest funcChatCompletionRequest = FuncChatCompletionRequest.baseBuild("What is the weather like in Boston?");
funcChatCompletionRequest.setTools(Arrays.asList(tool));
funcChatCompletionRequest.setToolChoice("auto");
// 获取请求结果
ChatCompletionResponse chatCompletionResponse = aggregationSession.getChatSession().chatCompletions(NULL, NULL, NULL, funcChatCompletionRequest);
log.info("测试结果:{}", chatCompletionResponse);
}
/**
* 测试图片对话,需要GPT4权限
*/
@Test
public void test_img_chat_completions() {
// 构造对话内容
Content textContent = Content.BuildTextContent("这张图片当中有什么?");
Content imgContent = Content.BuildImageUrlContent("https://oaidalleapiprodscus.blob.core.windows.net/private/org-HL3RbCOW1GSH0YPFWak9m6be/user-AzfIxMQpzvc9raSA0TZ9sHOw/img-oft6JFrL4ilB4mmapSud8Vpy.png?st=2024-03-08T13%3A31%3A32Z&se=2024-03-08T15%3A31%3A32Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2024-03-07T18%3A11%3A55Z&ske=2024-03-08T18%3A11%3A55Z&sks=b&skv=2021-08-06&sig=wLcx9Uo6XiYT10GXIFcJwfd08BKoJ07k7cP7qJvYGx4%3D");
// 构造 msg 内容
ImgMessage imgMessage = ImgMessage.builder().role(Constants.Role.USER.getRoleName()).content(Arrays.asList(textContent, imgContent)).build();
// 构造请求参数,chatGPT 4 支持图片对话
ImgChatCompletionRequest imgChatCompletionRequest = ImgChatCompletionRequest.builder().model(BaseChatCompletionRequest.Model.GPT_4_VISION_PREVIEW.getModuleName()).messages(Arrays.asList(imgMessage)).build();
// 获取结果
ChatCompletionResponse chatCompletionResponse = aggregationSession.getChatSession().chatCompletions(NULL, NULL, NULL, imgChatCompletionRequest);
log.info("测试结果:{}", chatCompletionResponse);
}
/**
* 测试聊天对话,流式返回结果。
*/
@Test
public void test_chat_completions_stream() throws InterruptedException, JsonProcessingException {
// 建造者模式构造参数
DefaultChatCompletionRequest defaultChatCompletionRequest = DefaultChatCompletionRequest.builder()
.stream(true)// 开启流式返回
.messages(Collections.singletonList(DefaultMessage
.builder()
.role(Constants.Role.USER.getRoleName())
.content("讲一个笑话")
.build()))
.model(BaseChatCompletionRequest.Model.GPT_3_5_TURBO.getModuleName())
.build();
aggregationSession.getChatSession().chatCompletions(NULL, NULL, NULL, defaultChatCompletionRequest, new EventSourceListener() {
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
log.info("测试结果 id:{} type:{} data:{}", id, type, data);
}
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
log.error("失败 code:{} message:{}", response.code(), response.message());
}
});
// 阻塞等待
new CountDownLatch(1).await();
}
/**
* 测试 CompletableFuture 异步调用。
*/
@Test
public void test_chat_completions_future() throws JsonProcessingException, InterruptedException, ExecutionException {
// 构造请求参数
DefaultChatCompletionRequest defaultChatCompletionRequest = DefaultChatCompletionRequest.builder()
.stream(true)
.messages(Collections.singletonList(DefaultMessage
.builder()
.role(Constants.Role.USER.getRoleName())
.content("1+1=")
.build()))
.model(BaseChatCompletionRequest.Model.GPT_3_5_TURBO.getModuleName()).build();
// 等待结果
CompletableFuture future = aggregationSession.getChatSession().chatCompletionsFuture(NULL, NULL, NULL, defaultChatCompletionRequest);
log.info("测试结果:{}", future.get());
}
}
================================================
FILE: ai-openai/src/test/java/com/ai/openai/EmbeddingApiTest.java
================================================
package com.ai.openai;
import com.ai.core.strategy.impl.FirstKeyStrategy;
import com.ai.openai.achieve.Configuration;
import com.ai.openai.achieve.defaults.DefaultOpenAiSessionFactory;
import com.ai.openai.achieve.standard.session.AggregationSession;
import com.ai.openai.endPoint.embeddings.EmbeddingObject;
import com.ai.openai.endPoint.embeddings.req.EmbeddingCompletionRequest;
import com.ai.openai.endPoint.embeddings.resp.EmbeddingCompletionResponse;
import lombok.extern.slf4j.Slf4j;
import org.junit.Before;
import org.junit.Test;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static com.ai.core.exception.Constants.NULL;
/**
 * @Description: 测试嵌入相关接口功能
 *
 * <p>Integration tests for the embedding endpoint. They call the real API: fill in an API key
 * and adjust or remove the local proxy in {@link #test_OpenAiSessionFactory()} before running.
 */
@Slf4j
public class EmbeddingApiTest {
    private AggregationSession aggregationSession;

    @Before
    public void test_OpenAiSessionFactory() {
        // 1. 创建配置类
        Configuration configuration = new Configuration();
        // 2. 设置请求地址,若有代理商或者代理服务器,可填写为代理服务器的请求路径
        configuration.setApiHost("https://api.openai.com");
        // 3. 设置鉴权所需的API Key,可设置多个。
        configuration.setKeyList(Arrays.asList("**************************"));
        // 4. 设置请求时 key 的使用策略,默认实现了:随机获取 和 固定第一个Key 两种方式。
        configuration.setKeyStrategy(new FirstKeyStrategy());
        // configuration.setKeyStrategy(new RandomKeyStrategy());
        // 5. 设置代理,若不需要可不设置
        configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)));
        // 6. 创建 session 工厂,制造不同场景的 session
        DefaultOpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration);
        this.aggregationSession = factory.openAggregationSession();
    }

    /**
     * 测试单个文本
     */
    @Test
    public void test_embedding() {
        EmbeddingCompletionResponse embeddingCompletionResponse = aggregationSession.getEmbeddingSession().embeddingCompletions(NULL, NULL, NULL, "你好");
        for (EmbeddingObject embeddingObject : embeddingCompletionResponse.getData()) {
            // Dimension of each returned vector.
            log.info("向量维度:{}", embeddingObject.getEmbedding().length);
        }
        log.info("返回结果:{}", embeddingCompletionResponse);
        log.info("返回结果:{}", embeddingCompletionResponse.getData().size());
    }

    /**
     * 测试多个文本嵌入
     */
    @Test
    public void test_embedding_list() {
        List<String> inputList = new ArrayList<>();
        inputList.add("你好");
        inputList.add("世界");
        EmbeddingCompletionResponse embeddingCompletionResponse = aggregationSession.getEmbeddingSession().embeddingCompletions(NULL, NULL, NULL, inputList);
        log.info("返回结果:{}", embeddingCompletionResponse);
    }

    /**
     * Tests embedding several texts via an explicitly built {@link EmbeddingCompletionRequest}.
     */
    @Test
    public void test_embedding_req() {
        List<String> inputList = new ArrayList<>();
        inputList.add("你好");
        inputList.add("世界");
        EmbeddingCompletionRequest embeddingCompletionRequest = EmbeddingCompletionRequest.baseBuild(inputList);
        EmbeddingCompletionResponse embeddingCompletionResponse = aggregationSession.getEmbeddingSession().embeddingCompletions(NULL, NULL, NULL, embeddingCompletionRequest);
        log.info("返回结果:{}", embeddingCompletionResponse);
    }
}
================================================
FILE: ai-openai/src/test/java/com/ai/openai/FilesApiTest.java
================================================
package com.ai.openai;
import com.ai.core.strategy.impl.FirstKeyStrategy;
import com.ai.openai.achieve.Configuration;
import com.ai.openai.achieve.defaults.DefaultOpenAiSessionFactory;
import com.ai.openai.achieve.standard.session.AggregationSession;
import com.ai.openai.endPoint.files.FileObject;
import com.ai.openai.endPoint.files.resp.DeleteFileResponse;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.Arrays;
import java.util.List;
import static com.ai.core.exception.Constants.NULL;
/**
 * Integration tests for the file endpoints (list / upload / delete / retrieve / content) of the OpenAI client.
 * Every test performs a real HTTP call, so a valid API key (and the local proxy, if kept) is required.
 */
@Slf4j
public class FilesApiTest {
    private AggregationSession aggregationSession;

    /**
     * Builds a fresh aggregation session before each test.
     */
    @Before
    public void test_OpenAiSessionFactory() {
        // 1. Create the configuration object
        Configuration configuration = new Configuration();
        // 2. Request host; use the relay / proxy-server URL here if one is used
        configuration.setApiHost("https://api.openai.com");
        // 3. API keys used for authentication; several may be supplied
        configuration.setKeyList(Arrays.asList("填入你的API Key"));
        // 4. Key selection strategy: FirstKeyStrategy (first key) or RandomKeyStrategy (random key)
        configuration.setKeyStrategy(new FirstKeyStrategy());
        // 5. Optional HTTP proxy
        configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)));
        // 6. Build the factory and open the aggregation session
        DefaultOpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration);
        this.aggregationSession = factory.openAggregationSession();
    }

    /**
     * Lists the files of the account.
     */
    @Test
    public void test_list_file() {
        List fileObjects = aggregationSession.getFilesSession().listFilesCompletions(NULL, NULL, NULL);
        log.info("测试结果:{}", fileObjects);
    }

    /**
     * Uploads a local file for the "fine-tune" purpose.
     */
    @Test
    public void test_upload_file() {
        File file = new File("src/main/resources/test1.txt");
        FileObject fileObject = aggregationSession.getFilesSession().uploadFileCompletions(NULL, NULL, NULL, file, "fine-tune");
        log.info("测试结果:{}", fileObject);
    }

    /**
     * Deletes a file by id.
     */
    @Test
    public void test_delete_file() {
        DeleteFileResponse deleteFileResponse = aggregationSession.getFilesSession().deleteFileCompletions(NULL, NULL, NULL, "file-B3CAfSS2ibv7cFmSbl5m1CPI");
        log.info("测试结果:{}", deleteFileResponse);
    }

    /**
     * Retrieves a file's metadata by id.
     */
    @Test
    public void test_retrieve_file() {
        FileObject fileObject = aggregationSession.getFilesSession().retrieveFileCompletions(NULL, NULL, NULL, "file-B3CAfSS2ibv7cFmSbl5m1CPI");
        log.info("测试结果:{}", fileObject);
    }

    /**
     * Retrieves a file's content and logs it as text.
     * The okhttp {@link ResponseBody} holds the underlying connection and must be closed,
     * hence the try-with-resources.
     *
     * @throws Exception if reading the body fails
     */
    @Test
    public void test_retrieve_file_context() throws Exception {
        try (ResponseBody responseBody = aggregationSession.getFilesSession().retrieveFileContextCompletions(NULL, NULL, NULL, "file-B3CAfSS2ibv7cFmSbl5m1CPI")) {
            // Log the actual content rather than the ResponseBody object's toString()
            log.info("测试结果:{}", responseBody == null ? null : responseBody.string());
        }
    }
}
================================================
FILE: ai-openai/src/test/java/com/ai/openai/FineTuningApiTest.java
================================================
package com.ai.openai;
import com.ai.core.strategy.impl.FirstKeyStrategy;
import com.ai.openai.achieve.Configuration;
import com.ai.openai.achieve.defaults.DefaultOpenAiSessionFactory;
import com.ai.openai.achieve.standard.session.AggregationSession;
import com.ai.openai.common.CommonListResponse;
import com.ai.openai.endPoint.fineTuning.FineTuningEvent;
import com.ai.openai.endPoint.fineTuning.req.FineTuningRequest;
import com.ai.openai.endPoint.fineTuning.req.ListFineTuningRequest;
import com.ai.openai.endPoint.fineTuning.resp.FineTuningResponse;
import lombok.extern.slf4j.Slf4j;
import org.junit.Before;
import org.junit.Test;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.Arrays;
import static com.ai.core.exception.Constants.NULL;
/**
 * Integration tests for the fine-tuning job endpoints of the OpenAI client.
 * Every test performs a real HTTP call, so a valid API key (and the local proxy, if kept) is required.
 */
@Slf4j
public class FineTuningApiTest {

    private AggregationSession aggregationSession;

    /**
     * Builds a fresh aggregation session before each test: API host, key list,
     * key-selection strategy and HTTP proxy are configured, then the factory opens the session.
     */
    @Before
    public void test_OpenAiSessionFactory() {
        Configuration configuration = new Configuration();
        // Use a relay / proxy-server URL here instead if the official host is not reachable directly.
        configuration.setApiHost("https://api.openai.com");
        // Several keys may be supplied; the strategy below decides which one a request uses.
        configuration.setKeyList(Arrays.asList("填入你的API Key"));
        // FirstKeyStrategy always uses the first key; RandomKeyStrategy is the built-in alternative.
        configuration.setKeyStrategy(new FirstKeyStrategy());
        // Optional: drop this line when no HTTP proxy is needed.
        configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)));
        this.aggregationSession = new DefaultOpenAiSessionFactory(configuration).openAggregationSession();
    }

    /**
     * Creates a fine-tuning job from an already uploaded training file.
     */
    @Test
    public void test_create_fine_tuning() {
        FineTuningRequest request = FineTuningRequest.builder()
                .trainingFile("file-BK7bzQj3FfZFXr7DbL6xJwfo")
                .build();
        FineTuningResponse created =
                aggregationSession.getFineTuningSession().createFineTuningJobCompletions(NULL, NULL, NULL, request);
        log.info("返回结果:{}", created);
    }

    /**
     * Lists fine-tuning jobs using an (empty) request object.
     */
    @Test
    public void test_list_fine_tuning_1() {
        ListFineTuningRequest query = ListFineTuningRequest.builder().build();
        CommonListResponse page =
                aggregationSession.getFineTuningSession().listFineTuningJobsCompletions(NULL, NULL, NULL, query);
        log.info("返回结果:{}", page);
    }

    /**
     * Lists fine-tuning jobs using the overload that takes the paging values directly (both left null).
     */
    @Test
    public void test_list_fine_tuning_2() {
        CommonListResponse page =
                aggregationSession.getFineTuningSession().listFineTuningJobsCompletions(NULL, NULL, NULL, null, null);
        log.info("返回结果:{}", page);
    }

    /**
     * Retrieves a single fine-tuning job by id.
     */
    @Test
    public void test_retrieve_fine_tuning() {
        FineTuningResponse job =
                aggregationSession.getFineTuningSession().retrieveFineTuningJobCompletions(NULL, NULL, NULL, "ft-AF1WoRqd3aJAHsqc9NY7iL8F");
        log.info("返回结果:{}", job);
    }

    /**
     * Cancels a fine-tuning job by id.
     */
    @Test
    public void test_cancel_fine_tuning() {
        FineTuningResponse cancelled =
                aggregationSession.getFineTuningSession().cancelFineTuningJobCompletions(NULL, NULL, NULL, "ft-AF1WoRqd3aJAHsqc9NY7iL8F");
        log.info("返回结果:{}", cancelled);
    }

    /**
     * Lists the events of a fine-tuning job.
     */
    @Test
    public void test_list_fine_tuning_events() {
        CommonListResponse events =
                aggregationSession.getFineTuningSession().listFineTuningEventsCompletions(NULL, NULL, NULL, "ftjob-abc123");
        log.info("返回结果:{}", events);
    }
}
================================================
FILE: ai-openai/src/test/java/com/ai/openai/ImageApiTest.java
================================================
package com.ai.openai;
import com.ai.common.utils.ImageUtils;
import com.ai.core.strategy.impl.FirstKeyStrategy;
import com.ai.openai.achieve.Configuration;
import com.ai.openai.achieve.defaults.DefaultOpenAiSessionFactory;
import com.ai.openai.achieve.standard.session.AggregationSession;
import com.ai.openai.endPoint.images.ImageObject;
import com.ai.openai.endPoint.images.req.CreateImageRequest;
import com.ai.openai.endPoint.images.req.ImageEditRequest;
import com.ai.openai.endPoint.images.req.ImageVariationRequest;
import lombok.extern.slf4j.Slf4j;
import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.Arrays;
import java.util.List;
import static com.ai.core.exception.Constants.NULL;
/**
 * Integration tests for the image endpoints (generate / edit / variation) of the OpenAI client.
 * Every test performs a real HTTP call, so a valid API key (and the local proxy, if kept) is required.
 * The tests read and write files under hard-coded local Windows paths.
 */
@Slf4j
public class ImageApiTest {

    /** Source image used by the edit and variation tests. */
    private static final String EDIT_IMAGE_PATH = "D:\\chatGPT-api\\AI-java\\doc\\test\\test_edit_image.png";

    private AggregationSession aggregationSession;

    /**
     * Builds a fresh aggregation session before each test: API host, key list,
     * key-selection strategy and HTTP proxy are configured, then the factory opens the session.
     */
    @Before
    public void test_OpenAiSessionFactory() {
        Configuration configuration = new Configuration();
        // Use a relay / proxy-server URL here instead if the official host is not reachable directly.
        configuration.setApiHost("https://api.openai.com");
        // Several keys may be supplied; the strategy below decides which one a request uses.
        configuration.setKeyList(Arrays.asList("填入你的API Key"));
        // FirstKeyStrategy always uses the first key; RandomKeyStrategy is the built-in alternative.
        configuration.setKeyStrategy(new FirstKeyStrategy());
        // Optional: drop this line when no HTTP proxy is needed.
        configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)));
        this.aggregationSession = new DefaultOpenAiSessionFactory(configuration).openAggregationSession();
    }

    /**
     * Generates images as Base64 JSON and writes each one to a numbered PNG file.
     */
    @Test
    public void test_create_image_b64json() {
        CreateImageRequest request = CreateImageRequest.baseBuild("画一个花园,花园里面有蝴蝶。");
        request.setResponseFormat(CreateImageRequest.Format.B64JSON.getFormat());
        List<ImageObject> images = aggregationSession.getImageSession().createImageCompletions(NULL, NULL, NULL, request);
        int index = 0;
        for (ImageObject image : images) {
            String target = "D:\\chatGPT-api\\AI-java\\doc\\test\\test_openai_create_image_" + index + ".png";
            ImageUtils.convertBase64StrToImage(image.getB64Json(), target);
            index++;
        }
    }

    /**
     * Generates images as URLs and prints each URL.
     */
    @Test
    public void test_create_image_url() {
        CreateImageRequest request = CreateImageRequest.baseBuild("画一个花园,花园里面有蝴蝶。");
        request.setResponseFormat(CreateImageRequest.Format.URL.getFormat());
        List<ImageObject> images = aggregationSession.getImageSession().createImageCompletions(NULL, NULL, NULL, request);
        for (ImageObject image : images) {
            System.out.println(image.getUrl());
        }
    }

    /**
     * Edits an existing image according to a prompt (no mask is supplied).
     */
    @Test
    public void test_edit_image() {
        File source = new File(EDIT_IMAGE_PATH);
        ImageEditRequest request = ImageEditRequest.baseBuild("给小熊的背后加上一只梅花鹿。");
        List<ImageObject> edited = aggregationSession.getImageSession().editImageCompletions(NULL, NULL, NULL, source, null, request);
        log.info("测试结果:{}", edited);
    }

    /**
     * Creates variations of an existing image with default request settings.
     */
    @Test
    public void test_variation_image() {
        File source = new File(EDIT_IMAGE_PATH);
        ImageVariationRequest request = ImageVariationRequest.builder().build();
        List<ImageObject> variations = aggregationSession.getImageSession().variationImageCompletions(NULL, NULL, NULL, source, request);
        log.info("测试结果:{}", variations);
    }
}
================================================
FILE: ai-openai/src/test/java/com/ai/openai/ModelApiTest.java
================================================
package com.ai.openai;
import com.ai.core.strategy.impl.FirstKeyStrategy;
import com.ai.openai.achieve.Configuration;
import com.ai.openai.achieve.defaults.DefaultOpenAiSessionFactory;
import com.ai.openai.achieve.standard.session.AggregationSession;
import com.ai.openai.endPoint.models.ModelObject;
import com.ai.openai.endPoint.models.resp.DeleteFineTuneModelResponse;
import lombok.extern.slf4j.Slf4j;
import org.junit.Before;
import org.junit.Test;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.Arrays;
import java.util.List;
import static com.ai.core.exception.Constants.NULL;
/**
 * Integration tests for the model endpoints (list / retrieve / delete fine-tuned model) of the OpenAI client.
 * Every test performs a real HTTP call, so a valid API key (and the local proxy, if kept) is required.
 */
@Slf4j
public class ModelApiTest {

    private AggregationSession aggregationSession;

    /**
     * Builds a fresh aggregation session before each test: API host, key list,
     * key-selection strategy and HTTP proxy are configured, then the factory opens the session.
     */
    @Before
    public void test_OpenAiSessionFactory() {
        Configuration configuration = new Configuration();
        // Use a relay / proxy-server URL here instead if the official host is not reachable directly.
        configuration.setApiHost("https://api.openai.com");
        // Several keys may be supplied; the strategy below decides which one a request uses.
        configuration.setKeyList(Arrays.asList("填入你的API Key"));
        // FirstKeyStrategy always uses the first key; RandomKeyStrategy is the built-in alternative.
        configuration.setKeyStrategy(new FirstKeyStrategy());
        // Optional: drop this line when no HTTP proxy is needed.
        configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)));
        this.aggregationSession = new DefaultOpenAiSessionFactory(configuration).openAggregationSession();
    }

    /**
     * Lists all models visible to the key.
     */
    @Test
    public void test_list_model() {
        List models = aggregationSession.getModelSession().listModelCompletions(NULL, NULL, NULL);
        log.info("返回结果:{}", models);
    }

    /**
     * Retrieves a single model by name.
     */
    @Test
    public void test_retrieve_model() {
        ModelObject model = aggregationSession.getModelSession().retrieveModelCompletions(NULL, NULL, NULL, "gpt-3.5-turbo-instruct");
        log.info("返回结果:{}", model);
    }

    /**
     * Deletes a fine-tuned model by name.
     * NOTE(review): the id used here is a base model name, so the server will likely reject it.
     */
    @Test
    public void test_delete_fine_tune_model() {
        DeleteFineTuneModelResponse deletion =
                aggregationSession.getModelSession().deleteFineTuneModelCompletions(NULL, NULL, NULL, "gpt-3.5-turbo-instruct");
        log.info("返回结果:{}", deletion);
    }
}
================================================
FILE: ai-openai/src/test/java/com/ai/openai/ModerationApiTest.java
================================================
package com.ai.openai;
import com.ai.core.strategy.impl.FirstKeyStrategy;
import com.ai.openai.achieve.Configuration;
import com.ai.openai.achieve.defaults.DefaultOpenAiSessionFactory;
import com.ai.openai.achieve.standard.session.AggregationSession;
import com.ai.openai.endPoint.moderations.Result;
import com.ai.openai.endPoint.moderations.req.ModerationRequest;
import com.ai.openai.endPoint.moderations.resp.ModerationResponse;
import lombok.extern.slf4j.Slf4j;
import org.junit.Before;
import org.junit.Test;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import static com.ai.core.exception.Constants.NULL;
/**
 * Integration test for the moderation endpoint of the OpenAI client.
 * Performs a real HTTP call, so a valid API key (and the local proxy, if kept) is required.
 */
@Slf4j
public class ModerationApiTest {

    private AggregationSession aggregationSession;

    /**
     * Builds a fresh aggregation session before each test: API host, key list,
     * key-selection strategy and HTTP proxy are configured, then the factory opens the session.
     */
    @Before
    public void test_OpenAiSessionFactory() {
        Configuration configuration = new Configuration();
        // Use a relay / proxy-server URL here instead if the official host is not reachable directly.
        configuration.setApiHost("https://api.openai.com");
        // Several keys may be supplied; the strategy below decides which one a request uses.
        configuration.setKeyList(Arrays.asList("**************************"));
        // FirstKeyStrategy always uses the first key; RandomKeyStrategy is the built-in alternative.
        configuration.setKeyStrategy(new FirstKeyStrategy());
        // Optional: drop this line when no HTTP proxy is needed.
        configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)));
        this.aggregationSession = new DefaultOpenAiSessionFactory(configuration).openAggregationSession();
    }

    /**
     * Sends one harmless and one violent sentence and prints each result's content and flag.
     */
    @Test
    public void test_moderation() {
        ArrayList<String> inputs = new ArrayList<>(Arrays.asList("你好", "我要杀了你"));
        ModerationRequest request = ModerationRequest.builder().input(inputs).build();
        ModerationResponse response = aggregationSession.getModerationSession().moderationCompletions(NULL, NULL, NULL, request);
        for (Result verdict : response.getResults()) {
            String line = verdict.getContent() + " " + verdict.isFlagged();
            System.out.println(line);
        }
    }
}
================================================
FILE: ai-spark/pom.xml
================================================
AI-java
com.ai
1.0
4.0.0
ai-spark
UTF-8
UTF-8
1.8
1.8
1.8
4.13.2
com.ai
ai-common
1.0
com.ai
ai-core
1.0
junit
junit
${junit.version}
test
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/ApiData.java
================================================
package com.ai.spark.achieve;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Holds one set of Spark API credentials: application id, API key and API secret.
 * Field declaration order matters: Lombok's {@code @AllArgsConstructor} takes the
 * parameters in this order (appId, apiKey, apiSecret).
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ApiData {
    /** Application id. */
    private String appId;
    /** API key. */
    private String apiKey;
    /** API secret. */
    private String apiSecret;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/Configuration.java
================================================
package com.ai.spark.achieve;
import com.ai.core.config.BaseConfiguration;
import com.ai.spark.achieve.standard.api.SparkApiServer;
import lombok.*;
import org.jetbrains.annotations.NotNull;
import java.util.List;
/**
 * Spark client configuration: the shared base settings plus the Retrofit API server
 * and the pool of credentials to pick from.
 */
@Data
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@NoArgsConstructor
@AllArgsConstructor
public class Configuration extends BaseConfiguration {
    /**
     * Retrofit API server; the session factory creates and stores it when a session is opened.
     */
    private SparkApiServer sparkApiServer;
    /**
     * Candidate credentials; the configured key strategy chooses one of them.
     */
    @NotNull
    private List keyList;

    /**
     * Picks a credential from {@code keyList} using the configured key strategy.
     *
     * @return the credential selected by the strategy, cast to {@link ApiData}
     */
    public ApiData getSystemApiData() {
        Object chosen = getKeyStrategy().apply(this.keyList);
        return (ApiData) chosen;
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/defaults/DefaultSparkSessionFactory.java
================================================
package com.ai.spark.achieve.defaults;
import com.ai.core.factory.SessionFactory;
import com.ai.spark.achieve.Configuration;
import com.ai.spark.achieve.defaults.session.DefaultAggregationSession;
import com.ai.spark.achieve.standard.api.SparkApiServer;
import com.ai.spark.achieve.standard.session.AggregationSession;
import com.ai.spark.interceptor.BaseUrlInterceptor;
import com.ai.spark.interceptor.ResponseInterceptor;
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;
import java.util.concurrent.TimeUnit;
import static com.ai.common.utils.ValidationUtils.ensureNotNull;
public class DefaultSparkSessionFactory implements SessionFactory {

    /** Connect, write and read timeout applied to the HTTP client, in seconds. */
    private static final long TIMEOUT_SECONDS = 450L;

    private final Configuration configuration;

    /**
     * @param configuration Spark configuration; must not be null
     */
    public DefaultSparkSessionFactory(Configuration configuration) {
        this.configuration = ensureNotNull(configuration, "configuration");
    }

    /**
     * Builds the OkHttp client with the logging (level NONE), base-URL and response
     * interceptors, uniform timeouts and, when configured, an HTTP proxy.
     */
    @Override
    public OkHttpClient createHttpClient() {
        HttpLoggingInterceptor loggingInterceptor = new HttpLoggingInterceptor();
        loggingInterceptor.setLevel(HttpLoggingInterceptor.Level.NONE);

        OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder();
        clientBuilder.addInterceptor(loggingInterceptor);
        clientBuilder.addInterceptor(new BaseUrlInterceptor());
        // Inspects responses before they reach Retrofit
        clientBuilder.addInterceptor(new ResponseInterceptor());
        clientBuilder.connectTimeout(TIMEOUT_SECONDS, TimeUnit.SECONDS);
        clientBuilder.writeTimeout(TIMEOUT_SECONDS, TimeUnit.SECONDS);
        clientBuilder.readTimeout(TIMEOUT_SECONDS, TimeUnit.SECONDS);

        boolean proxyConfigured = configuration.getProxy() != null;
        if (proxyConfigured) {
            clientBuilder.proxy(configuration.getProxy());
        }
        return clientBuilder.build();
    }

    /**
     * Creates the Retrofit-backed {@link SparkApiServer} on top of the given client,
     * using Jackson for (de)serialization and RxJava2 call adapters.
     */
    @Override
    public SparkApiServer createApiServer(OkHttpClient okHttpClient) {
        Retrofit retrofit = new Retrofit.Builder()
                .baseUrl(configuration.getApiHost())
                .client(okHttpClient)
                .addCallAdapterFactory(RxJava2CallAdapterFactory.create())
                .addConverterFactory(JacksonConverterFactory.create())
                .build();
        return retrofit.create(SparkApiServer.class);
    }

    /**
     * Creates the HTTP client and API server, stores both in the configuration,
     * and returns a session aggregating every Spark scenario.
     */
    @Override
    public AggregationSession openAggregationSession() {
        OkHttpClient client = createHttpClient();
        configuration.setOkHttpClient(client);
        configuration.setSparkApiServer(createApiServer(client));
        return new DefaultAggregationSession(configuration);
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/defaults/listener/BaseListener.java
================================================
package com.ai.spark.achieve.defaults.listener;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.WebSocket;
import okhttp3.WebSocketListener;
import okio.ByteString;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@Slf4j
@Data
public abstract class BaseListener extends WebSocketListener {
/**
* WebSocket服务发生异常的回调,可以覆盖重写。
* 默认抛出异常
*
* @param t 异常
* @param response 返回值
*/
public void onWebSocketError(Throwable t, Response response) {
log.error("调用星火模型时,WebSocket发生异常:{}", response);
t.printStackTrace();
}
/**
* 星火大模型发生异常
*
* @param resp 大模型返回值
*/
public abstract void onChatError(RESP resp);
/**
* 星火大模型正常返回信息
*
* @param resp 大模型返回值
*/
public abstract void onChatOutput(RESP resp);
/**
* 星火大模型返回信息结束回调
*/
public abstract void onChatEnd();
@Override
public void onClosed(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
super.onClosed(webSocket, code, reason);
}
@Override
public void onClosing(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
super.onClosing(webSocket, code, reason);
}
@Override
public void onFailure(@NotNull WebSocket webSocket, @NotNull Throwable t, @Nullable Response response) {
webSocket.close(1000, "");
this.onWebSocketError(t, response);
}
@Override
public abstract void onMessage(@NotNull WebSocket webSocket, @NotNull String text);
@Override
public void onMessage(@NotNull WebSocket webSocket, @NotNull ByteString bytes) {
super.onMessage(webSocket, bytes);
}
@Override
public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
super.onOpen(webSocket, response);
}
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/defaults/listener/ChatListener.java
================================================
package com.ai.spark.achieve.defaults.listener;
import com.ai.common.utils.JsonUtils;
import com.ai.spark.common.Usage;
import com.ai.spark.endPoint.chat.ChatHeader;
import com.ai.spark.endPoint.chat.Choice;
import com.ai.spark.endPoint.chat.req.ChatRequest;
import com.ai.spark.endPoint.chat.resp.ChatResponse;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.WebSocket;
import okio.ByteString;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import static com.ai.common.utils.ValidationUtils.ensureNotNull;
/**
* 星火大模型流式返回监听器
*/
@Slf4j
@Data
public abstract class ChatListener extends BaseListener {
private ChatRequest chatRequest;
public ChatListener(ChatRequest chatRequest) {
this.chatRequest = ensureNotNull(chatRequest, "chatRequest");
}
/**
* WebSocket服务发生异常的回调,可以覆盖重写。
* 默认抛出异常
*
* @param t 异常
* @param response 返回值
*/
public void onWebSocketError(Throwable t, Response response) {
log.error("调用星火模型时,WebSocket发生异常:{}", response);
t.printStackTrace();
}
/**
* 星火大模型发生异常
*
* @param chatResponse 大模型返回值
*/
public abstract void onChatError(ChatResponse chatResponse);
/**
* 星火大模型正常返回信息
*
* @param chatResponse 大模型返回值
*/
public abstract void onChatOutput(ChatResponse chatResponse);
/**
* 星火大模型返回信息结束回调
*/
public abstract void onChatEnd();
@Override
public void onClosed(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
super.onClosed(webSocket, code, reason);
}
@Override
public void onClosing(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
super.onClosing(webSocket, code, reason);
}
@Override
public void onFailure(@NotNull WebSocket webSocket, @NotNull Throwable t, @Nullable Response response) {
webSocket.close(1000, "");
this.onWebSocketError(t, response);
}
@Override
public void onMessage(@NotNull WebSocket webSocket, @NotNull String text) {
ChatResponse chatResponse = JsonUtils.fromJson(text, ChatResponse.class);
if (ChatHeader.Code.SUCCESS.getValue() != chatResponse.getChatHeader().getCode()) {
log.warn("调用星火模型发生错误,错误码为:{},请求的sid为:{}", chatResponse.getChatHeader().getCode(), chatResponse.getChatHeader().getSid());
webSocket.close(1000, "星火模型调用异常");
this.onChatError(chatResponse);
return;
}
this.onChatOutput(chatResponse);
if (Choice.Status.END.getValue() == chatResponse.getChatHeader().getStatus()) {
// 可以关闭连接,释放资源
webSocket.close(1000, "星火模型返回结束");
Usage usage = chatResponse.getChatPayload().getUsage();
this.onChatEnd();
}
}
@Override
public void onMessage(@NotNull WebSocket webSocket, @NotNull ByteString bytes) {
super.onMessage(webSocket, bytes);
}
@Override
public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
super.onOpen(webSocket, response);
webSocket.send(JsonUtils.toJson(chatRequest));
}
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/defaults/listener/DocumentChatListener.java
================================================
package com.ai.spark.achieve.defaults.listener;
import com.ai.common.utils.JsonUtils;
import com.ai.spark.endPoint.chat.req.DocumentChatRequest;
import com.ai.spark.endPoint.chat.resp.DocumentChatResponse;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.WebSocket;
import okio.ByteString;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import static com.ai.common.utils.ValidationUtils.ensureNotNull;
/**
* 星火大模型文档对话流式返回监听器
*/
@Slf4j
@Data
public abstract class DocumentChatListener extends BaseListener {
private DocumentChatRequest documentChatRequest;
/**
* 构造方法,传入大模型参数
*
* @param documentChatRequest 大模型参数
*/
public DocumentChatListener(DocumentChatRequest documentChatRequest) {
this.documentChatRequest = ensureNotNull(documentChatRequest, "documentChatRequest");
}
/**
* WebSocket服务发生异常的回调,可以覆盖重写。
* 默认抛出异常
*
* @param t 异常
* @param response 返回值
*/
public void onWebSocketError(Throwable t, Response response) {
log.error("调用星火模型时,WebSocket发生异常:{}", response);
t.printStackTrace();
}
/**
* 星火大模型发生异常
*
* @param documentChatResponse 大模型返回值
*/
public abstract void onChatError(DocumentChatResponse documentChatResponse);
/**
* 星火大模型正常返回信息
*
* @param documentChatResponse 大模型返回值
*/
public abstract void onChatOutput(DocumentChatResponse documentChatResponse);
/**
* 星火大模型返回信息结束回调
*/
public abstract void onChatEnd();
@Override
public void onClosed(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
super.onClosed(webSocket, code, reason);
}
@Override
public void onClosing(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
super.onClosing(webSocket, code, reason);
}
@Override
public void onFailure(@NotNull WebSocket webSocket, @NotNull Throwable t, @Nullable Response response) {
webSocket.close(1000, "");
this.onWebSocketError(t, response);
}
@Override
public void onMessage(@NotNull WebSocket webSocket, @NotNull String text) {
DocumentChatResponse documentChatResponse = JsonUtils.fromJson(text, DocumentChatResponse.class);
if (DocumentChatResponse.Code.SUCCESS.getValue() != documentChatResponse.getCode()) {
log.warn("调用星火模型文档对话发生错误,错误码为:{},请求的sid为:{}", documentChatResponse.getCode(), documentChatResponse.getSid());
webSocket.close(1000, "星火模型调用异常");
this.onChatError(documentChatResponse);
return;
}
this.onChatOutput(documentChatResponse);
if (DocumentChatResponse.Status.END.getValue() == documentChatResponse.getStatus()) {
// 可以关闭连接,释放资源
webSocket.close(1000, "星火模型返回结束");
this.onChatEnd();
}
}
@Override
public void onMessage(@NotNull WebSocket webSocket, @NotNull ByteString bytes) {
super.onMessage(webSocket, bytes);
}
@Override
public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
super.onOpen(webSocket, response);
webSocket.send(JsonUtils.toJson(documentChatRequest));
}
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/defaults/listener/ImageUnderstandingListener.java
================================================
package com.ai.spark.achieve.defaults.listener;
import com.ai.common.utils.JsonUtils;
import com.ai.spark.endPoint.chat.ChatHeader;
import com.ai.spark.endPoint.chat.Choice;
import com.ai.spark.endPoint.chat.resp.DocumentChatResponse;
import com.ai.spark.endPoint.images.ImageHeader;
import com.ai.spark.endPoint.images.req.ImageUnderstandingRequest;
import com.ai.spark.endPoint.images.resp.ImageUnderstandingResponse;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.WebSocket;
import okio.ByteString;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import static com.ai.common.utils.ValidationUtils.ensureNotNull;
/**
* 图片理解对话监听器
*/
@Slf4j
@Data
public abstract class ImageUnderstandingListener extends BaseListener {
private ImageUnderstandingRequest imageUnderstandingRequest;
/**
* 构造方法,传入大模型参数
*
* @param imageUnderstandingRequest 大模型参数
*/
public ImageUnderstandingListener(ImageUnderstandingRequest imageUnderstandingRequest) {
this.imageUnderstandingRequest = ensureNotNull(imageUnderstandingRequest, "documentChatRequest");
}
/**
* WebSocket服务发生异常的回调,可以覆盖重写。
* 默认抛出异常
*
* @param t 异常
* @param response 返回值
*/
public void onWebSocketError(Throwable t, Response response) {
log.error("调用星火模型时,WebSocket发生异常:{}", response);
t.printStackTrace();
}
/**
* 星火大模型发生异常
*
* @param imageUnderstandingResponse 大模型返回值
*/
public abstract void onChatError(ImageUnderstandingResponse imageUnderstandingResponse);
/**
* 星火大模型正常返回信息
*
* @param imageUnderstandingResponse 大模型返回值
*/
public abstract void onChatOutput(ImageUnderstandingResponse imageUnderstandingResponse);
/**
* 星火大模型返回信息结束回调
*/
public abstract void onChatEnd();
@Override
public void onClosed(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
super.onClosed(webSocket, code, reason);
}
@Override
public void onClosing(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
super.onClosing(webSocket, code, reason);
}
@Override
public void onFailure(@NotNull WebSocket webSocket, @NotNull Throwable t, @Nullable Response response) {
webSocket.close(1000, "");
this.onWebSocketError(t, response);
}
@Override
public void onMessage(@NotNull WebSocket webSocket, @NotNull String text) {
ImageUnderstandingResponse imageUnderstandingResponse = JsonUtils.fromJson(text, ImageUnderstandingResponse.class);
ImageHeader imageHeader = imageUnderstandingResponse.getImageHeader();
if (imageHeader.getCode() != ChatHeader.Code.SUCCESS.getValue()) {
log.warn("调用星火模型文档对话发生错误,错误码为:{},请求的sid为:{}", imageHeader.getCode(), imageHeader.getSid());
webSocket.close(1000, "星火模型调用异常");
this.onChatError(imageUnderstandingResponse);
return;
}
this.onChatOutput(imageUnderstandingResponse);
if (DocumentChatResponse.Status.END.getValue() == Choice.Status.END.getValue()) {
// 可以关闭连接,释放资源
webSocket.close(1000, "星火模型返回结束");
this.onChatEnd();
}
}
@Override
public void onMessage(@NotNull WebSocket webSocket, @NotNull ByteString bytes) {
super.onMessage(webSocket, bytes);
}
@Override
public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
super.onOpen(webSocket, response);
webSocket.send(JsonUtils.toJson(imageUnderstandingRequest));
}
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/defaults/session/DefaultAggregationSession.java
================================================
package com.ai.spark.achieve.defaults.session;
import com.ai.spark.achieve.Configuration;
import com.ai.spark.achieve.standard.session.*;
import static com.ai.common.utils.ValidationUtils.ensureNotNull;
/**
 * Aggregates the per-scenario Spark sessions. Each session is created lazily on first access
 * and then reused. Creation uses double-checked locking on {@code this} over volatile fields,
 * so concurrent callers see a single instance.
 */
public class DefaultAggregationSession implements AggregationSession {
    private final Configuration configuration;

    private volatile ChatSession chatSession;
    private volatile DocumentSession documentSession;
    private volatile EmbeddingSession embeddingSession;
    private volatile ImageSession imageSession;
    private volatile AudioSession audioSession;

    /**
     * @param configuration shared Spark configuration handed to every session; must not be null
     */
    public DefaultAggregationSession(Configuration configuration) {
        this.configuration = ensureNotNull(configuration, "configuration");
    }

    @Override
    public ChatSession getChatSession() {
        // Read the volatile field once on the fast path
        ChatSession local = chatSession;
        if (local == null) {
            synchronized (this) {
                local = chatSession;
                if (local == null) {
                    local = new DefaultChatSession(configuration);
                    chatSession = local;
                }
            }
        }
        return local;
    }

    @Override
    public DocumentSession getDocumentSession() {
        DocumentSession local = documentSession;
        if (local == null) {
            synchronized (this) {
                local = documentSession;
                if (local == null) {
                    local = new DefaultDocumentSession(configuration);
                    documentSession = local;
                }
            }
        }
        return local;
    }

    @Override
    public EmbeddingSession getEmbeddingSession() {
        EmbeddingSession local = embeddingSession;
        if (local == null) {
            synchronized (this) {
                local = embeddingSession;
                if (local == null) {
                    local = new DefaultEmbeddingSession(configuration);
                    embeddingSession = local;
                }
            }
        }
        return local;
    }

    @Override
    public ImageSession getImageSession() {
        ImageSession local = imageSession;
        if (local == null) {
            synchronized (this) {
                local = imageSession;
                if (local == null) {
                    local = new DefaultImageSession(configuration);
                    imageSession = local;
                }
            }
        }
        return local;
    }

    @Override
    public AudioSession getAudioSession() {
        AudioSession local = audioSession;
        if (local == null) {
            synchronized (this) {
                local = audioSession;
                if (local == null) {
                    local = new DefaultAudioSession(configuration);
                    audioSession = local;
                }
            }
        }
        return local;
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/defaults/session/DefaultAudioSession.java
================================================
package com.ai.spark.achieve.defaults.session;
import com.ai.spark.achieve.Configuration;
import com.ai.spark.achieve.standard.session.AudioSession;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import static com.ai.common.utils.ValidationUtils.ensureNotNull;
/**
 * Default audio session. It currently only stores the validated configuration and API server
 * inherited from {@link Session}; {@link AudioSession} declares no operations yet.
 */
@Data
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class DefaultAudioSession extends Session implements AudioSession {
    /**
     * @param configuration SDK configuration; must not be null and must provide a SparkApiServer
     */
    public DefaultAudioSession(Configuration configuration) {
        Configuration validated = ensureNotNull(configuration, "configuration");
        this.setConfiguration(validated);
        this.setSparkApiServer(ensureNotNull(validated.getSparkApiServer(), "sparkApiServer"));
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/defaults/session/DefaultChatSession.java
================================================
package com.ai.spark.achieve.defaults.session;
import cn.hutool.core.date.DateUtil;
import com.ai.spark.achieve.ApiData;
import com.ai.spark.achieve.Configuration;
import com.ai.spark.achieve.defaults.listener.ChatListener;
import com.ai.spark.achieve.defaults.listener.DocumentChatListener;
import com.ai.spark.achieve.standard.session.ChatSession;
import com.ai.spark.common.SparkApiUrl;
import com.ai.spark.common.utils.AuthUtils;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.SneakyThrows;
import lombok.ToString;
import okhttp3.Request;
import okhttp3.WebSocket;
import static com.ai.common.utils.ValidationUtils.ensureNotNull;
import static com.ai.spark.common.SparkApiUrl.DOCUMENT_CHAT;
/**
 * Default chat session: opens WebSocket connections for plain chat and for document chat.
 */
@Data
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class DefaultChatSession extends Session implements ChatSession {
    /**
     * @param configuration SDK configuration; must not be null and must provide a SparkApiServer
     */
    public DefaultChatSession(Configuration configuration) {
        this.setConfiguration(ensureNotNull(configuration, "configuration"));
        this.setSparkApiServer(ensureNotNull(configuration.getSparkApiServer(), "sparkApiServer"));
    }

    /**
     * Starts a chat using the system-level API credentials from the configuration.
     *
     * @param chatListener listener that carries the chat request and receives the responses
     * @return the opened WebSocket
     */
    @Override
    @SneakyThrows
    public WebSocket chat(T chatListener) {
        ApiData apiData = this.getConfiguration().getSystemApiData();
        return this.chat(apiData.getApiKey(), apiData.getApiSecret(), chatListener);
    }

    /**
     * Starts a chat with caller-supplied credentials. The endpoint is chosen from the request's
     * {@code domain}, and the URL is signed with {@link AuthUtils#getAuthUrl} and converted to ws/wss.
     *
     * @param apiKey       the user's API key
     * @param apiSecret    the user's API secret
     * @param chatListener listener that carries the chat request and receives the responses
     * @return the opened WebSocket
     */
    @Override
    @SneakyThrows
    public WebSocket chat(String apiKey, String apiSecret, T chatListener) {
        String domain = chatListener.getChatRequest().getChatParameter().getChat().getDomain();
        // SparkApiUrl.getUrl returns null for unknown domains; fail with a clear message instead of
        // an opaque NPE from new URI(null) inside getAuthUrl.
        String hostUrl = ensureNotNull(SparkApiUrl.getUrl(domain), "chat url for domain " + domain);
        String url = AuthUtils.replaceAllHttp(
                AuthUtils.getAuthUrl(AuthUtils.RequestMethod.GET.getMethod(), hostUrl, apiKey, apiSecret)
        );
        return this.getConfiguration().getOkHttpClient().newWebSocket(new Request.Builder().url(url).build(), chatListener);
    }

    /**
     * Starts a document chat using the system-level appId and API secret.
     *
     * @param documentChatListener listener that receives the document-chat responses
     * @return the opened WebSocket
     */
    @Override
    public WebSocket documentChat(T documentChatListener) {
        ApiData apiData = this.getConfiguration().getSystemApiData();
        return documentChat(apiData.getAppId(), apiData.getApiSecret(), documentChatListener);
    }

    /**
     * Starts a document chat with caller-supplied credentials. The query string carries appId, the
     * current timestamp in seconds, and the signature from {@link AuthUtils#getSignature}.
     *
     * @param appId                the user's appId
     * @param apiSecret            the user's API secret
     * @param documentChatListener listener that receives the document-chat responses
     * @return the opened WebSocket
     */
    @Override
    public WebSocket documentChat(String appId, String apiSecret, T documentChatListener) {
        long ts = DateUtil.currentSeconds();
        // The timestamp must be sent as "&timestamp=" (the previous literal was corrupted to "×tamp=",
        // so the timestamp parameter never reached the server).
        String url = SparkApiUrl.getUrl(DOCUMENT_CHAT)
                + "?appId=" + appId
                + "&timestamp=" + ts
                + "&signature=" + AuthUtils.getSignature(appId, apiSecret, ts);
        return this.getConfiguration().getOkHttpClient().newWebSocket(new Request.Builder().url(url).build(), documentChatListener);
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/defaults/session/DefaultDocumentSession.java
================================================
package com.ai.spark.achieve.defaults.session;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.util.StrUtil;
import com.ai.spark.achieve.ApiData;
import com.ai.spark.achieve.Configuration;
import com.ai.spark.achieve.standard.session.DocumentSession;
import com.ai.spark.common.SparkApiUrl;
import com.ai.spark.common.utils.AuthUtils;
import com.ai.spark.endPoint.document.req.FileUploadRequest;
import com.ai.spark.endPoint.document.resp.DocumentSummaryResponse;
import com.ai.spark.endPoint.document.resp.FileUploadResponse;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import okhttp3.MediaType;
import okhttp3.MultipartBody;
import okhttp3.RequestBody;
import java.util.HashMap;
import java.util.Map;
import static com.ai.common.utils.ValidationUtils.ensureNotNull;
/**
 * Default document session: file upload and document-summary start/query against the ChatDoc API.
 * Every request is authenticated with appId, the current timestamp in seconds, and
 * {@link AuthUtils#getSignature}.
 */
@Data
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class DefaultDocumentSession extends Session implements DocumentSession {
    /** Content type used for every multipart part sent by this session (same value as before). */
    private static final MediaType FORM_DATA = MediaType.parse("multipart/form-data");

    /**
     * @param configuration SDK configuration; must not be null and must provide a SparkApiServer
     */
    public DefaultDocumentSession(Configuration configuration) {
        this.setConfiguration(ensureNotNull(configuration, "configuration"));
        this.setSparkApiServer(ensureNotNull(configuration.getSparkApiServer(), "sparkApiServer"));
    }

    @Override
    public FileUploadResponse fileUpload(FileUploadRequest fileUploadRequest) {
        ApiData apiData = this.getConfiguration().getSystemApiData();
        return this.fileUpload(apiData.getAppId(), apiData.getApiSecret(), fileUploadRequest);
    }

    /**
     * Uploads a document. The local file part is sent only when a file is present, and the optional
     * text fields (url, fileName, fileType, callbackUrl) are sent only when they are not blank.
     *
     * @param appId             the user's appId
     * @param apiSecret         the user's API secret
     * @param fileUploadRequest upload parameters; must not be null
     * @return the upload response (blocks until the call completes)
     */
    @Override
    public FileUploadResponse fileUpload(String appId, String apiSecret, FileUploadRequest fileUploadRequest) {
        ensureNotNull(fileUploadRequest, "fileUploadRequest");
        long ts = DateUtil.currentSeconds();
        // Previously a missing file caused an NPE even when a url was supplied. The part is now left
        // null in that case, and Retrofit skips null @Part arguments.
        // NOTE(review): assumes the upload API accepts a url instead of a file - confirm with API docs.
        MultipartBody.Part filePart = null;
        if (fileUploadRequest.getFile() != null) {
            RequestBody fileBody = RequestBody.create(FORM_DATA, fileUploadRequest.getFile());
            filePart = MultipartBody.Part.createFormData("file", fileUploadRequest.getFile().getName(), fileBody);
        }
        Map<String, RequestBody> requestBodyMap = new HashMap<>();
        putIfNotBlank(requestBodyMap, FileUploadRequest.Fields.url, fileUploadRequest.getUrl());
        putIfNotBlank(requestBodyMap, FileUploadRequest.Fields.fileName, fileUploadRequest.getFileName());
        putIfNotBlank(requestBodyMap, FileUploadRequest.Fields.fileType, fileUploadRequest.getFileType());
        putIfNotBlank(requestBodyMap, FileUploadRequest.Fields.callbackUrl, fileUploadRequest.getCallbackUrl());
        return this.getSparkApiServer().fileUpload(appId, String.valueOf(ts), AuthUtils.getSignature(appId, apiSecret, ts), filePart, requestBodyMap).blockingGet();
    }

    @Override
    public DocumentSummaryResponse documentSummaryStart(String fileId) {
        ApiData apiData = this.getConfiguration().getSystemApiData();
        return this.documentSummaryStart(apiData.getAppId(), apiData.getApiSecret(), fileId);
    }

    @Override
    public DocumentSummaryResponse documentSummaryStart(String appId, String apiSecret, String fileId) {
        return this.documentSummary(SparkApiUrl.ApiUrl.documentSummaryStart.getUrl(), appId, apiSecret, fileId);
    }

    @Override
    public DocumentSummaryResponse documentSummaryQuery(String fileId) {
        ApiData apiData = this.getConfiguration().getSystemApiData();
        return this.documentSummaryQuery(apiData.getAppId(), apiData.getApiSecret(), fileId);
    }

    @Override
    public DocumentSummaryResponse documentSummaryQuery(String appId, String apiSecret, String fileId) {
        return this.documentSummary(SparkApiUrl.ApiUrl.documentSummaryQuery.getUrl(), appId, apiSecret, fileId);
    }

    /**
     * Shared implementation of document-summary start and query; both endpoints take the same
     * parameters and differ only in URL.
     *
     * @param url       endpoint URL
     * @param appId     the user's appId
     * @param apiSecret the user's API secret
     * @param fileId    file ID; must not be null
     * @return the response (blocks until the call completes)
     */
    private DocumentSummaryResponse documentSummary(String url, String appId, String apiSecret, String fileId) {
        ensureNotNull(fileId, "fileId");
        long ts = DateUtil.currentSeconds();
        return this.getSparkApiServer().documentSummary(url, appId, String.valueOf(ts), AuthUtils.getSignature(appId, apiSecret, ts), RequestBody.create(FORM_DATA, fileId)).blockingGet();
    }

    /** Adds {@code value} as a form-data part under {@code name} when it is not blank. */
    private static void putIfNotBlank(Map<String, RequestBody> parts, String name, String value) {
        if (StrUtil.isNotBlank(value)) {
            parts.put(name, RequestBody.create(FORM_DATA, value));
        }
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/defaults/session/DefaultEmbeddingSession.java
================================================
package com.ai.spark.achieve.defaults.session;
import cn.hutool.http.ContentType;
import com.ai.common.utils.JsonUtils;
import com.ai.spark.achieve.ApiData;
import com.ai.spark.achieve.Configuration;
import com.ai.spark.achieve.standard.session.EmbeddingSession;
import com.ai.spark.common.SparkApiUrl;
import com.ai.spark.common.utils.AuthUtils;
import com.ai.spark.endPoint.embedding.req.EmbeddingRequest;
import com.ai.spark.endPoint.embedding.resp.EmbeddingResponse;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.SneakyThrows;
import lombok.ToString;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import static com.ai.common.utils.ValidationUtils.ensureNotNull;
@Data
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class DefaultEmbeddingSession extends Session implements EmbeddingSession {
public DefaultEmbeddingSession(Configuration configuration) {
this.setConfiguration(ensureNotNull(configuration, "configuration"));
this.setSparkApiServer(ensureNotNull(configuration.getSparkApiServer(), "sparkApiServer"));
}
@Override
public EmbeddingResponse embed(EmbeddingRequest embeddingRequest) {
ApiData apiData = this.getConfiguration().getSystemApiData();
return this.embed(apiData.getApiKey(), apiData.getApiSecret(), embeddingRequest);
}
@Override
@SneakyThrows
public EmbeddingResponse embed(String apiKey, String apiSecret, EmbeddingRequest embeddingRequest) {
// 鉴权,得到请求路径
String authUrl = AuthUtils.getAuthUrl(AuthUtils.RequestMethod.POST.getMethod(), SparkApiUrl.ApiUrl.embeddingq.getUrl(), apiKey, apiSecret);
// 创建请求,设置请求URL和json数据
Request request = new Request.Builder().url(authUrl).post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), JsonUtils.toJson(embeddingRequest))).build();
// 发起请求,获取返回的json字符串
String response = this.getConfiguration().getOkHttpClient().newCall(request).execute().body().string();
// 将json映射到对象上
return JsonUtils.fromJson(response, EmbeddingResponse.class);
}
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/defaults/session/DefaultImageSession.java
================================================
package com.ai.spark.achieve.defaults.session;
import cn.hutool.http.ContentType;
import com.ai.common.utils.JsonUtils;
import com.ai.spark.achieve.ApiData;
import com.ai.spark.achieve.Configuration;
import com.ai.spark.achieve.defaults.listener.ImageUnderstandingListener;
import com.ai.spark.achieve.standard.session.ImageSession;
import com.ai.spark.common.SparkApiUrl;
import com.ai.spark.common.utils.AuthUtils;
import com.ai.spark.endPoint.images.req.ImageCreateRequest;
import com.ai.spark.endPoint.images.req.ImageUnderstandingRequest;
import com.ai.spark.endPoint.images.resp.ImageCreateResponse;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.SneakyThrows;
import lombok.ToString;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.WebSocket;
import static com.ai.common.utils.ValidationUtils.ensureNotNull;
/**
 * Default image session: image creation over signed HTTP POST and image understanding over a signed
 * WebSocket.
 */
@Data
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class DefaultImageSession extends Session implements ImageSession {
    /**
     * @param configuration SDK configuration; must not be null and must provide a SparkApiServer
     */
    public DefaultImageSession(Configuration configuration) {
        this.setConfiguration(ensureNotNull(configuration, "configuration"));
        this.setSparkApiServer(ensureNotNull(configuration.getSparkApiServer(), "sparkApiServer"));
    }

    @Override
    public ImageCreateResponse imageCreate(ImageCreateRequest imageCreateRequest) {
        ApiData apiData = this.getConfiguration().getSystemApiData();
        return this.imageCreate(apiData.getApiKey(), apiData.getApiSecret(), imageCreateRequest);
    }

    /**
     * Creates an image with caller-supplied credentials and maps the JSON reply.
     *
     * @param apiKey             the user's API key
     * @param apiSecret          the user's API secret
     * @param imageCreateRequest request payload, serialized to JSON
     * @return the deserialized response
     * @throws IllegalStateException if the HTTP response carries no body
     */
    @Override
    @SneakyThrows
    public ImageCreateResponse imageCreate(String apiKey, String apiSecret, ImageCreateRequest imageCreateRequest) {
        String authUrl = AuthUtils.getAuthUrl(AuthUtils.RequestMethod.POST.getMethod(), SparkApiUrl.ApiUrl.imageCreate.getUrl(), apiKey, apiSecret);
        RequestBody body = RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), JsonUtils.toJson(imageCreateRequest));
        Request request = new Request.Builder().url(authUrl).post(body).build();
        // Close the response deterministically, even if mapping fails.
        try (okhttp3.Response httpResponse = this.getConfiguration().getOkHttpClient().newCall(request).execute()) {
            okhttp3.ResponseBody responseBody = httpResponse.body();
            if (responseBody == null) {
                throw new IllegalStateException("Image create request returned no body, HTTP " + httpResponse.code());
            }
            return JsonUtils.fromJson(responseBody.string(), ImageCreateResponse.class);
        }
    }

    @Override
    public WebSocket imageUnderstanding(ImageUnderstandingRequest imageUnderstandingRequest, ImageUnderstandingListener imageUnderstandingListener) {
        ApiData apiData = this.getConfiguration().getSystemApiData();
        return imageUnderstanding(apiData.getApiKey(), apiData.getApiSecret(), imageUnderstandingRequest, imageUnderstandingListener);
    }

    /**
     * Opens the signed image-understanding WebSocket.
     * <p>
     * NOTE(review): {@code imageUnderstandingRequest} is not used by this method. The payload that is
     * actually sent is presumably the request held by the listener, which it sends in its own
     * {@code onOpen}. Callers should make sure the listener holds the same request; confirm.
     *
     * @param apiKey                     the user's API key
     * @param apiSecret                  the user's API secret
     * @param imageUnderstandingRequest  request parameters (currently unused here, see note)
     * @param imageUnderstandingListener listener that receives the responses
     * @return the opened WebSocket
     */
    @Override
    @SneakyThrows
    public WebSocket imageUnderstanding(String apiKey, String apiSecret, ImageUnderstandingRequest imageUnderstandingRequest, ImageUnderstandingListener imageUnderstandingListener) {
        String url = AuthUtils.replaceAllHttp(
                AuthUtils.getAuthUrl(AuthUtils.RequestMethod.GET.getMethod(), SparkApiUrl.ApiUrl.imageUnderstanding.getUrl(), apiKey, apiSecret)
        );
        return this.getConfiguration().getOkHttpClient().newWebSocket(new Request.Builder().url(url).build(), imageUnderstandingListener);
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/defaults/session/Session.java
================================================
package com.ai.spark.achieve.defaults.session;
import com.ai.spark.achieve.Configuration;
import com.ai.spark.achieve.standard.api.SparkApiServer;
import lombok.Data;
/**
 * Base class of the default session implementations.
 * <p>
 * Holds the SDK configuration and the SparkApiServer; the Default*Session subclasses set both in
 * their constructors. Lombok {@code @Data} generates the accessors, equals/hashCode and toString.
 */
@Data
public class Session {
// SDK configuration (credentials strategy, OkHttp client, API server).
private Configuration configuration;
// Retrofit-declared HTTP API used by sessions that call REST endpoints.
private SparkApiServer sparkApiServer;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/standard/api/SparkApiServer.java
================================================
package com.ai.spark.achieve.standard.api;
import com.ai.spark.endPoint.document.resp.DocumentSummaryResponse;
import com.ai.spark.endPoint.document.resp.FileUploadResponse;
import io.reactivex.Single;
import okhttp3.MultipartBody;
import okhttp3.RequestBody;
import retrofit2.http.*;
import java.util.Map;
import static com.ai.spark.common.Constants.*;
import static com.ai.spark.common.SparkApiUrl.FILE_UPLOAD_API_URL;
/**
 * iFlytek Spark HTTP API, declared with Retrofit annotations.
 * Authentication values are passed as the {@code appId}, {@code timestamp} and {@code signature} headers.
 */
public interface SparkApiServer {
/**
 * File upload endpoint (multipart POST to {@code FILE_UPLOAD_API_URL}).
 *
 * @param appId          the user's appId
 * @param timestamp      current timestamp in seconds
 * @param signature      generated signature
 * @param file           the file part to upload
 * @param requestBodyMap additional form fields
 * @return an RxJava {@code Single} emitting the upload response
 */
@Multipart
@POST(FILE_UPLOAD_API_URL)
Single fileUpload(@Header(APP_ID) String appId, @Header(TIMESTAMP) String timestamp, @Header(SIGNATURE) String signature, @Part MultipartBody.Part file, @PartMap Map requestBodyMap);
/**
 * Document summary. Starting a summary and querying its result share this declaration because
 * both take the same parameters and differ only in request URL.
 *
 * @param url       full request URL
 * @param appId     the user's appId
 * @param timestamp current timestamp in seconds
 * @param signature generated signature
 * @param fileId    file ID, sent as the {@code fileId} form part
 * @return an RxJava {@code Single} emitting the response
 */
@Multipart
@POST
Single documentSummary(@Url String url, @Header(APP_ID) String appId, @Header(TIMESTAMP) String timestamp, @Header(SIGNATURE) String signature, @Part("fileId") RequestBody fileId);
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/standard/session/AggregationSession.java
================================================
package com.ai.spark.achieve.standard.session;
/**
 * Single access point to the per-scenario sessions: chat, document, embedding, image and audio.
 */
public interface AggregationSession {
/** @return the session for chat operations */
ChatSession getChatSession();
/** @return the session for document operations (upload, summary) */
DocumentSession getDocumentSession();
/** @return the session for text-embedding operations */
EmbeddingSession getEmbeddingSession();
/** @return the session for image operations */
ImageSession getImageSession();
/** @return the session for audio operations */
AudioSession getAudioSession();
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/standard/session/AudioSession.java
================================================
package com.ai.spark.achieve.standard.session;
/**
 * Operations for the audio scenario. It currently declares no methods and is a placeholder.
 */
public interface AudioSession {
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/standard/session/ChatSession.java
================================================
package com.ai.spark.achieve.standard.session;
import com.ai.spark.achieve.defaults.listener.ChatListener;
import com.ai.spark.achieve.defaults.listener.DocumentChatListener;
import okhttp3.WebSocket;
/**
 * Operations for the chat scenario: plain chat and document chat over WebSocket.
 */
public interface ChatSession {
/**
 * Chat using the system-level API credentials.
 *
 * @param chatListener chat listener
 * @return the opened WebSocket
 */
WebSocket chat(T chatListener);
/**
 * Chat using a caller-supplied ApiKey and ApiSecret.
 *
 * @param apiKey       the user's ApiKey
 * @param apiSecret    the user's ApiSecret
 * @param chatListener chat listener
 * @return the opened WebSocket
 */
WebSocket chat(String apiKey, String apiSecret, T chatListener);
/**
 * Document chat using the system-level credentials.
 *
 * @param documentChatListener document-chat listener
 * @return the opened WebSocket
 */
WebSocket documentChat(T documentChatListener);
/**
 * Document chat using a caller-supplied AppId and ApiSecret.
 *
 * @param appId                the user's AppId
 * @param apiSecret            the user's ApiSecret
 * @param documentChatListener document-chat listener
 * @return the opened WebSocket
 */
WebSocket documentChat(String appId, String apiSecret, T documentChatListener);
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/standard/session/DocumentSession.java
================================================
package com.ai.spark.achieve.standard.session;
import com.ai.spark.endPoint.document.req.FileUploadRequest;
import com.ai.spark.endPoint.document.resp.DocumentSummaryResponse;
import com.ai.spark.endPoint.document.resp.FileUploadResponse;
/**
 * Operations for the document scenario: file upload and document summary.
 */
public interface DocumentSession {
/**
 * Uploads a file using the system-level credentials.
 *
 * @param fileUploadRequest upload parameters
 * @return the upload response
 */
FileUploadResponse fileUpload(FileUploadRequest fileUploadRequest);
/**
 * Uploads a file using a caller-supplied AppId and ApiSecret.
 *
 * @param appId             the user's AppId
 * @param apiSecret         the user's ApiSecret
 * @param fileUploadRequest upload parameters
 * @return the upload response
 */
FileUploadResponse fileUpload(String appId, String apiSecret, FileUploadRequest fileUploadRequest);
/**
 * Starts a document summary using the system-level credentials.
 *
 * @param fileId file ID
 * @return the response
 */
DocumentSummaryResponse documentSummaryStart(String fileId);
/**
 * Starts a document summary using a caller-supplied AppId and ApiSecret.
 *
 * @param appId     the user's AppId
 * @param apiSecret the user's ApiSecret
 * @param fileId    file ID
 * @return the response
 */
DocumentSummaryResponse documentSummaryStart(String appId, String apiSecret, String fileId);
/**
 * Queries a document summary result using the system-level credentials.
 *
 * @param fileId file ID
 * @return the response
 */
DocumentSummaryResponse documentSummaryQuery(String fileId);
/**
 * Queries a document summary result using a caller-supplied AppId and ApiSecret.
 *
 * @param appId     the user's AppId
 * @param apiSecret the user's ApiSecret
 * @param fileId    file ID
 * @return the response
 */
DocumentSummaryResponse documentSummaryQuery(String appId, String apiSecret, String fileId);
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/standard/session/EmbeddingSession.java
================================================
package com.ai.spark.achieve.standard.session;
import com.ai.spark.endPoint.embedding.req.EmbeddingRequest;
import com.ai.spark.endPoint.embedding.resp.EmbeddingResponse;
/**
 * Operations for the text-embedding scenario.
 */
public interface EmbeddingSession {
/**
 * Computes text embeddings using the system-level credentials.
 *
 * @param embeddingRequest request parameters
 * @return the response
 */
EmbeddingResponse embed(EmbeddingRequest embeddingRequest);
/**
 * Computes text embeddings using a caller-supplied ApiKey and ApiSecret.
 *
 * @param apiKey           the user's ApiKey
 * @param apiSecret        the user's ApiSecret
 * @param embeddingRequest request parameters
 * @return the response
 */
EmbeddingResponse embed(String apiKey, String apiSecret, EmbeddingRequest embeddingRequest);
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/achieve/standard/session/ImageSession.java
================================================
package com.ai.spark.achieve.standard.session;
import com.ai.spark.achieve.defaults.listener.ImageUnderstandingListener;
import com.ai.spark.endPoint.images.req.ImageCreateRequest;
import com.ai.spark.endPoint.images.req.ImageUnderstandingRequest;
import com.ai.spark.endPoint.images.resp.ImageCreateResponse;
import okhttp3.WebSocket;
/**
 * Operations for the image scenario: image creation and image understanding.
 */
public interface ImageSession {
    /**
     * Creates an image using the system-level API credentials.
     *
     * @param imageCreateRequest request parameters
     * @return the response
     */
    ImageCreateResponse imageCreate(ImageCreateRequest imageCreateRequest);

    /**
     * Creates an image using caller-supplied credentials.
     *
     * @param apiKey             the user's ApiKey
     * @param apiSecret          the user's ApiSecret
     * @param imageCreateRequest request parameters
     * @return the response
     */
    ImageCreateResponse imageCreate(String apiKey, String apiSecret, ImageCreateRequest imageCreateRequest);

    /**
     * Image understanding using the system-level API credentials.
     *
     * @param imageUnderstandingRequest  request parameters
     * @param imageUnderstandingListener listener that receives the streamed responses
     * @return the opened WebSocket
     */
    WebSocket imageUnderstanding(ImageUnderstandingRequest imageUnderstandingRequest, ImageUnderstandingListener imageUnderstandingListener);

    /**
     * Image understanding using caller-supplied credentials.
     *
     * @param apiKey                     the user's ApiKey
     * @param apiSecret                  the user's ApiSecret
     * @param imageUnderstandingRequest  request parameters
     * @param imageUnderstandingListener listener that receives the streamed responses
     * @return the opened WebSocket
     */
    WebSocket imageUnderstanding(String apiKey, String apiSecret, ImageUnderstandingRequest imageUnderstandingRequest, ImageUnderstandingListener imageUnderstandingListener);
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/common/Constants.java
================================================
package com.ai.spark.common;
/**
 * Names of the authentication headers and query parameters used by the ChatDoc (document) APIs.
 * This is a non-instantiable constants holder.
 */
public final class Constants {
    public static final String APP_ID = "appId";
    public static final String TIMESTAMP = "timestamp";
    public static final String SIGNATURE = "signature";

    private Constants() {
        // constants only
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/common/SparkApiUrl.java
================================================
package com.ai.spark.common;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import java.util.HashMap;
import java.util.Map;
/**
 * Request URLs of the individual Spark APIs, reachable by string key ({@link #getUrl}) or by
 * the {@link ApiUrl} enum.
 */
@Slf4j
public class SparkApiUrl {
    // Spark model chat endpoints, keyed by the request's domain value
    public final static String GENERAL_V1 = "general";
    public final static String GENERAL_V2 = "generalv2";
    public final static String GENERAL_V3 = "generalv3";
    public final static String SPARK_API_HOST_WS_V1_1_URL = "http://spark-api.xf-yun.com/v1.1/chat";
    public final static String SPARK_API_HOST_WSS_V1_1_URL = "https://spark-api.xf-yun.com/v1.1/chat";
    public final static String SPARK_API_HOST_WS_V2_1_URL = "http://spark-api.xf-yun.com/v2.1/chat";
    public final static String SPARK_API_HOST_WSS_V2_1_URL = "https://spark-api.xf-yun.com/v2.1/chat";
    public final static String SPARK_API_HOST_WS_V3_1_URL = "http://spark-api.xf-yun.com/v3.1/chat";
    public final static String SPARK_API_HOST_WSS_V3_1_URL = "https://spark-api.xf-yun.com/v3.1/chat";
    // Document-chat file upload
    public final static String FILE_UPLOAD = "fileUpload";
    public final static String FILE_UPLOAD_API_URL = "https://chatdoc.xfyun.cn/openapi/fileUpload";
    // Document chat
    public final static String DOCUMENT_CHAT = "documentChat";
    public final static String DOCUMENT_CHAT_API_URL = "wss://chatdoc.xfyun.cn/openapi/chat";
    // Start a document summary
    public final static String DOCUMENT_SUMMARY_START = "documentSummaryStart";
    public final static String DOCUMENT_SUMMARY_START_API_URL = "https://chatdoc.xfyun.cn/openapi/startSummary";
    // Query a document summary result. The key must differ from DOCUMENT_SUMMARY_START; it used to
    // be "documentSummaryStart" as well, so the query URL overwrote the start URL in urlMap.
    public final static String DOCUMENT_SUMMARY_QUERY = "documentSummaryQuery";
    public final static String DOCUMENT_SUMMARY_QUERY_API_URL = "https://chatdoc.xfyun.cn/openapi/fileSummary";
    // Text embedding
    public final static String EMBEDDING_P = "Embeddingp";
    public final static String EMBEDDING_P_API_URL = "https://cn-huabei-1.xf-yun.com/v1/private/sa8a05c27";
    public final static String EMBEDDING_Q = "Embeddingq";
    public final static String EMBEDDING_Q_API_URL = "https://cn-huabei-1.xf-yun.com/v1/private/s50d55a16";
    // Image creation
    public final static String IMAGE_CREATE = "imageCreate";
    public final static String IMAGE_CREATE_API_URL = "https://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti";
    // Image understanding (constant names keep their historical spelling for compatibility)
    public final static String IMAGE_UNDERSANDING = "imageUnderstanding";
    public final static String IMAGE_UNDERSANDING_API_URL = "https://spark-api.cn-huabei-1.xf-yun.com/v2.1/image";
    // Hypermimetic speech synthesis protocol
    public final static String HYPERMIMETIC_SYNTHESIS = "hypermimeticSynthesis";
    public final static String HYPERMIMETIC_SYNTHESIS_API_URL = "wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/medd90fec";
    /** Key to URL lookup table backing {@link #getUrl}. */
    public final static Map<String, String> urlMap = new HashMap<>();

    static {
        urlMap.put(GENERAL_V1, SPARK_API_HOST_WSS_V1_1_URL);
        urlMap.put(GENERAL_V2, SPARK_API_HOST_WSS_V2_1_URL);
        urlMap.put(GENERAL_V3, SPARK_API_HOST_WSS_V3_1_URL);
        urlMap.put(FILE_UPLOAD, FILE_UPLOAD_API_URL);
        urlMap.put(DOCUMENT_CHAT, DOCUMENT_CHAT_API_URL);
        urlMap.put(DOCUMENT_SUMMARY_START, DOCUMENT_SUMMARY_START_API_URL);
        urlMap.put(DOCUMENT_SUMMARY_QUERY, DOCUMENT_SUMMARY_QUERY_API_URL);
        urlMap.put(EMBEDDING_P, EMBEDDING_P_API_URL);
        urlMap.put(EMBEDDING_Q, EMBEDDING_Q_API_URL);
        urlMap.put(IMAGE_CREATE, IMAGE_CREATE_API_URL);
        urlMap.put(IMAGE_UNDERSANDING, IMAGE_UNDERSANDING_API_URL);
        urlMap.put(HYPERMIMETIC_SYNTHESIS, HYPERMIMETIC_SYNTHESIS_API_URL);
    }

    /**
     * Looks up the URL registered for {@code key}.
     *
     * @param key one of the key constants of this class
     * @return the URL, or {@code null} (with an error log) when the key is unknown
     */
    public static String getUrl(String key) {
        // Single lookup: the map never stores null values, so null means "not registered".
        String url = urlMap.get(key);
        if (url == null) {
            log.error("No corresponding URL path found for {}", key);
        }
        return url;
    }

    /**
     * Typed access to the same URLs.
     */
    @Getter
    @AllArgsConstructor
    public enum ApiUrl {
        general(SPARK_API_HOST_WSS_V1_1_URL),
        generalV2(SPARK_API_HOST_WSS_V2_1_URL),
        generalV3(SPARK_API_HOST_WSS_V3_1_URL),
        fileUpload(FILE_UPLOAD_API_URL),
        documentChat(DOCUMENT_CHAT_API_URL),
        documentSummaryStart(DOCUMENT_SUMMARY_START_API_URL),
        documentSummaryQuery(DOCUMENT_SUMMARY_QUERY_API_URL),
        embeddingp(EMBEDDING_P_API_URL),
        embeddingq(EMBEDDING_Q_API_URL),
        imageCreate(IMAGE_CREATE_API_URL),
        imageUnderstanding(IMAGE_UNDERSANDING_API_URL),
        hypermimeticSynthesis(HYPERMIMETIC_SYNTHESIS_API_URL);
        private final String url;
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/common/Usage.java
================================================
package com.ai.spark.common;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Token usage wrapper; the token counts live in the nested {@code text} JSON object.
 * Unknown JSON properties are ignored and null fields are omitted when serializing.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class Usage {
// Mapped from the JSON property "text".
@JsonProperty("text")
private UsageText usageText;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/common/UsageText.java
================================================
package com.ai.spark.common;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Token counts reported in the {@code text} object of a usage block.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class UsageText {
/**
 * Reserved field; can be ignored.
 */
@JsonProperty("question_tokens")
private Integer questionTokens;
/**
 * Total token count of the prompt, including historical questions.
 */
@JsonProperty("prompt_tokens")
private Integer promptTokens;
/**
 * Token count of the answer.
 */
@JsonProperty("completion_tokens")
private Integer completionTokens;
/**
 * Sum of prompt_tokens and completion_tokens.
 */
@JsonProperty("total_tokens")
private Integer totalTokens;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/common/utils/AuthUtils.java
================================================
package com.ai.spark.common.utils;
import lombok.AllArgsConstructor;
import lombok.Getter;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SignatureException;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Base64;
import java.util.Locale;
/**
 * Authentication helpers.
 * <p>
 * {@link #getAuthUrl} builds an HMAC-SHA256 signed URL (used for the chat, embedding and image
 * endpoints). {@link #getSignature} builds the MD5 + HMAC-SHA1 signature used by the document
 * (ChatDoc) endpoints.
 */
public class AuthUtils {
    /**
     * Formatter for the {@code date} value, e.g. {@code Mon, 01 Jan 2024 00:00:00 GMT}.
     */
    public final static DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
    /**
     * Text that is signed by {@link #getAuthUrl}: host, date and request line.
     */
    public final static String preStr = "host: %s\n" +
            "date: %s\n" +
            "%s %s HTTP/1.1";
    private static final char[] MD5_TABLE = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};

    private AuthUtils() {
        // static utility class
    }

    /**
     * Converts http/https URLs to ws/wss for WebSocket connections.
     *
     * @param authUrl URL to convert
     * @return the URL with {@code http://} replaced by {@code ws://} and {@code https://} by {@code wss://}
     */
    public static String replaceAllHttp(String authUrl) {
        // Literal replacement; no regular expression is needed.
        return authUrl.replace("http://", "ws://").replace("https://", "wss://");
    }

    /**
     * Builds the signed URL used by the chat-style endpoints.
     *
     * @param requestMethod HTTP method placed in the signed request line
     * @param hostUrl       endpoint URL
     * @param apiKey        API key
     * @param apiSecret     API secret used as the HMAC-SHA256 key
     * @return {@code hostUrl} with the {@code authorization}, {@code date} and {@code host} query parameters
     */
    public static String getAuthUrl(String requestMethod, String hostUrl, String apiKey, String apiSecret) throws InvalidKeyException, NoSuchAlgorithmException, URISyntaxException, UnsupportedEncodingException {
        URI uri = new URI(hostUrl);
        String host = uri.getHost();
        String date = ZonedDateTime.now(ZoneId.of("GMT")).format(dateTimeFormatter);
        Mac mac = Mac.getInstance("HmacSHA256");
        mac.init(new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "HmacSHA256"));
        byte[] signatureBytes = mac.doFinal(String.format(preStr, host, date, requestMethod, uri.getPath()).getBytes(StandardCharsets.UTF_8));
        String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"",
                apiKey, "hmac-sha256", "host date request-line", Base64.getEncoder().encodeToString(signatureBytes));
        return hostUrl
                + "?authorization=" + Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8))
                + "&date=" + URLEncoder.encode(date, StandardCharsets.UTF_8.name())
                + "&host=" + host;
    }

    /**
     * Computes the document-API signature: Base64(HmacSHA1(key = secret, data = md5Hex(appId + ts))).
     *
     * @param appId  application id
     * @param secret signing secret
     * @param ts     timestamp in seconds
     * @return the Base64 signature
     * @throws IllegalStateException if the signature cannot be computed (previously this returned
     *                               {@code null}, which silently produced unauthenticated requests)
     */
    public static String getSignature(String appId, String secret, long ts) {
        try {
            return hmacSHA1Encrypt(md5(appId + ts), secret);
        } catch (SignatureException e) {
            throw new IllegalStateException("Failed to compute signature for appId " + appId, e);
        }
    }

    /**
     * HMAC-SHA1 of {@code encryptText} keyed by {@code encryptKey}, Base64-encoded.
     *
     * @param encryptText text to sign
     * @param encryptKey  signing key
     * @return Base64 HMAC
     * @throws SignatureException wrapping the underlying JCA failure (cause preserved)
     */
    private static String hmacSHA1Encrypt(String encryptText, String encryptKey) throws SignatureException {
        try {
            Mac mac = Mac.getInstance("HmacSHA1");
            mac.init(new SecretKeySpec(encryptKey.getBytes(StandardCharsets.UTF_8), "HmacSHA1"));
            byte[] rawHmac = mac.doFinal(encryptText.getBytes(StandardCharsets.UTF_8));
            return Base64.getEncoder().encodeToString(rawHmac);
        } catch (InvalidKeyException | NoSuchAlgorithmException e) {
            throw new SignatureException(e.getClass().getSimpleName() + ":" + e.getMessage(), e);
        }
    }

    /**
     * Lowercase hex MD5 digest of the UTF-8 bytes of {@code text}.
     */
    private static String md5(String text) {
        try {
            // Explicit UTF-8 so the digest does not depend on the platform default charset.
            byte[] digest = MessageDigest.getInstance("MD5").digest(text.getBytes(StandardCharsets.UTF_8));
            char[] hex = new char[digest.length * 2];
            int k = 0;
            for (byte b : digest) {
                hex[k++] = MD5_TABLE[(b >>> 4) & 0xf];
                hex[k++] = MD5_TABLE[b & 0xf];
            }
            return new String(hex);
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform must provide MD5; reaching this means a broken runtime.
            throw new IllegalStateException("MD5 algorithm not available", e);
        }
    }

    /**
     * HTTP methods that can appear in the signed request line.
     */
    public enum RequestMethod {
        GET("GET"),
        POST("POST");

        private final String method;

        RequestMethod(String method) {
            this.method = method;
        }

        /**
         * @return the method name as it appears in the request line
         */
        public String getMethod() {
            return method;
        }
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/audio/Audio.java
================================================
package com.ai.spark.endPoint.audio;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Audio data block. It is used as {@code payload.audio} in {@link AudioPayload} and as the
 * {@code audio} settings on {@link Tts}.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class Audio {
    /**
     * Audio encoding.
     */
    private String encoding;
    /**
     * Audio sample rate (NOTE(review): presumably in Hz — confirm against the API docs).
     */
    @JsonProperty("sample_rate")
    private Integer sampleRate;
    /**
     * Number of audio channels.
     */
    private Integer channels;
    /**
     * Bit depth.
     */
    @JsonProperty("bit_depth")
    private Integer bitDepth;
    /**
     * Frame size.
     */
    @JsonProperty("frame_size")
    private Integer frameSize;
    /**
     * Data status: 0 = start, 1 = continue, 2 = end.
     */
    private Integer status;
    /**
     * Data sequence number. Min 0, max 9999999.
     */
    private Integer seq;
    /**
     * Audio data. Size 0 B to 10485760 B.
     * NOTE(review): presumably Base64-encoded, since it is carried as a String — confirm.
     */
    private String audio;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/audio/AudioHeader.java
================================================
package com.ai.spark.endPoint.audio;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Protocol header of the audio endpoint. The request-side fields identify the app, user and
 * device. {@code code}, {@code message} and {@code sid} are filled in by the service response.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class AudioHeader {
    /** Application id issued by the platform; required. */
    @JsonProperty("app_id")
    private String appId;
    /** User id returned by the user service; user- and device-level personalization rely on it. */
    private String uid;
    /** Device id the caller guarantees to be unique; device-level personalization relies on it. */
    private String did;
    /** Device IMEI. */
    private String imei;
    /** Device IMSI. */
    private String imsi;
    /** Device MAC address. */
    private String mac;
    /** Network type: one of wifi, 2G, 3G, 4G, 5G. */
    @JsonProperty("net_type")
    private String netType;
    /** Carrier: one of CMCC, CUCC, CTCC, other. */
    @JsonProperty("net_isp")
    private String netIsp;
    /** Client-assigned unique id of the session. */
    @JsonProperty("request_id")
    private String requestId;
    /** Personalization resource id. */
    @JsonProperty("res_id")
    private String resId;
    /** Request status: 0 = start, 1 = continue, 2 = end. */
    private Integer status;

    // ---- Fields populated in responses ----

    /** Error code: 0 means success, any other value is an error. */
    private Integer code;
    /** Description of whether the session succeeded. */
    private String message;
    /** Unique session id that iFlytek staff use to look up server logs; keep it when a call fails. */
    private String sid;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/audio/AudioParameter.java
================================================
package com.ai.spark.endPoint.audio;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * {@code parameter} section of an audio request: the switches that control engine features.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class AudioParameter {
    /** Oral-style settings; see {@link Oral}. */
    private Oral oral;
    /** Speech-synthesis settings; see {@link Tts}. */
    private Tts tts;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/audio/AudioPayload.java
================================================
package com.ai.spark.endPoint.audio;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * {@code payload} section shared by audio requests and audio responses.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class AudioPayload {
    /** Text to synthesize. */
    private AudioText text;
    /**
     * The user's original input. Required when scenario-based synthesis is enabled and
     * optional otherwise.
     */
    @JsonProperty("user_text")
    private AudioText userText;
    /** Response data block carrying audio. */
    private Audio audio;
    /** Response data block carrying {@link Pybuf} text data. */
    private Pybuf pybuf;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/audio/AudioText.java
================================================
package com.ai.spark.endPoint.audio;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * A text data block inside an audio payload (used for both {@code text} and {@code user_text}).
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class AudioText {
    /** Text encoding. */
    private String encoding;
    /** Text compression format. */
    private String compress;
    /** Text format. */
    private String format;
    /** Data status. */
    private Integer status;
    /** Data sequence number. */
    private Integer seq;
    /** The text data itself. */
    private String text;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/audio/Oral.java
================================================
package com.ai.spark.endPoint.audio;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Oral-style synthesis settings ({@code parameter.oral} of an audio request).
 * NOTE(review): the meaning and allowed values of these fields are not documented in this
 * code — see the iFlytek API documentation.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class Oral {
    /** Serialized as {@code spark_assist}. */
    @JsonProperty("spark_assist")
    private Integer sparkAssist;
    /** Serialized as {@code oral_level}. */
    @JsonProperty("oral_level")
    private String oralLevel;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/audio/Pybuf.java
================================================
package com.ai.spark.endPoint.audio;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Text data block returned as {@code payload.pybuf} in audio responses; {@link Tts} also
 * carries a {@code pybuf} settings field.
 * NOTE(review): presumably the pinyin annotation returned when {@code Tts.rhy == 1} — confirm.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class Pybuf {
    /**
     * Text encoding.
     */
    private String encoding;
    /**
     * Text compression format.
     */
    private String compress;
    /**
     * Text format.
     */
    private String format;
    /**
     * Data status: 0 = start, 1 = continue, 2 = end.
     */
    private Integer status;
    /**
     * Data sequence number. Min 0, max 9999999.
     */
    private Integer seq;
    /**
     * Text data. Size 0 B to 1048576 B.
     */
    private String text;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/audio/Tts.java
================================================
package com.ai.spark.endPoint.audio;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.*;
/**
 * Speech-synthesis engine parameters ({@code parameter.tts} of an audio request).
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class Tts {
    /**
     * Speaker (voice) identifier; required. See {@link Vcn} for known values.
     */
    private String vcn;
    /**
     * Speech rate: 0 = half the default rate, 100 = twice the default rate.
     * Min 0, max 100.
     */
    private Integer speed;
    /**
     * Volume: 0 = mute, 1 = half the default volume, 100 = twice the default volume.
     * Min 0, max 100.
     */
    private Integer volume;
    /**
     * Pitch: 0 = half the default pitch, 100 = twice the default pitch.
     * Min 0, max 100.
     */
    private Integer pitch;
    /**
     * Background sound: 0 = none, 1 = built-in background 1, 2 = built-in background 2.
     */
    private Integer bgs;
    /**
     * How English text is pronounced:
     * 0 = decide automatically, and read as words when unsure (default);
     * 1 = spell every English word letter by letter;
     * 2 = decide automatically, and spell letter by letter when unsure.
     */
    private Integer reg;
    /**
     * How digits are read: 0 = automatic, 1 = always as a number, 2 = always as a digit
     * string, 3 = prefer a digit string.
     */
    private Integer rdn;
    /**
     * Whether to return pinyin annotation: 0 = no, 1 = yes (plain text, UTF-8).
     */
    private Integer rhy;
    /**
     * Scenario: 0 = none, 1 = prose reading, 2 = novel reading, 3 = news,
     * 4 = advertisement, 5 = interactive.
     */
    private Integer scn;
    /**
     * Whether engine initialization returns version and timestamp info: 0 = no, 1 = yes
     * (e.g. {@code XXX.18928127}, where XXX is the version followed by a timestamp in seconds).
     */
    private Integer version;
    /**
     * L5 silence duration in ms. Min 0, max 10000.
     * <p>
     * NOTE(review): because the name starts with an upper-case letter, Lombok generates
     * {@code getL5SilLen()} and Jackson's default naming derives the JSON key
     * {@code l5SilLen}, not {@code L5SilLen}. If the service expects the upper-case key
     * verbatim, an explicit mapping is needed — confirm.
     */
    private Integer L5SilLen;
    /**
     * Paragraph silence duration in ms. Min 0, max 10000.
     * <p>
     * NOTE(review): as with {@code L5SilLen}, the JSON key is presumably derived as
     * {@code paragraphSilLen} — confirm the expected wire name.
     */
    private Integer ParagraphSilLen;
    /** Audio settings; see {@link Audio}. */
    private Audio audio;
    /** Pybuf settings; see {@link Pybuf}. */
    private Pybuf pybuf;

    /**
     * Known speakers. {@link #getName()} is the value to send as {@code vcn}.
     */
    @Getter
    @AllArgsConstructor
    public enum Vcn {
        lxx("x4_lingxiaoxuan_oral"),
        lfz("x4_lingfeizhe_oral"),
        lyz("x4_lingyuzhao_oral");

        /** Wire value for {@code vcn}; final because enum constants are shared singletons. */
        private final String name;
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/audio/req/AudioRequest.java
================================================
package com.ai.spark.endPoint.audio.req;
import com.ai.spark.endPoint.audio.AudioHeader;
import com.ai.spark.endPoint.audio.AudioParameter;
import com.ai.spark.endPoint.audio.AudioPayload;
import com.ai.spark.endPoint.audio.Tts;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Request envelope for the audio endpoint, made of header, parameter and payload sections.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class AudioRequest {
    /** Protocol header describing platform-level parameters. */
    private AudioHeader header;
    /** AI capability parameters: the switches that control engine features. */
    private AudioParameter parameter;
    /** Request payload, such as the text to synthesize. */
    private AudioPayload payload;

    /**
     * Creates a request pre-filled with the app id, header status 0 and the given speaker.
     * The payload is left unset for the caller to supply.
     *
     * @param vcn   speaker; its {@code name} becomes {@code parameter.tts.vcn}
     * @param appId application id placed in {@code header.app_id}
     * @return a new request whose {@code payload} is {@code null}
     */
    public static AudioRequest baseBuild(Tts.Vcn vcn, String appId) {
        AudioHeader startHeader = AudioHeader.builder().appId(appId).status(0).build();
        Tts tts = Tts.builder().vcn(vcn.getName()).build();
        return AudioRequest.builder()
                .header(startHeader)
                .parameter(AudioParameter.builder().tts(tts).build())
                .build();
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/audio/resp/AudioResponse.java
================================================
package com.ai.spark.endPoint.audio.resp;
import com.ai.spark.endPoint.audio.AudioHeader;
import com.ai.spark.endPoint.audio.AudioPayload;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Response frame of the audio endpoint. The header carries status and error information;
 * the payload carries the returned data blocks.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class AudioResponse {
    /** Response header: code, message, sid and status. */
    private AudioHeader header;
    /** Response payload: audio and/or pybuf data blocks. */
    private AudioPayload payload;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/Chat.java
================================================
package com.ai.spark.endPoint.chat;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.*;
/**
 * Chat engine settings ({@code parameter.chat} of a chat request).
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class Chat {
    /**
     * Required. Model domain: {@code general} targets V1.5, {@code generalv2} targets V2 and
     * {@code generalv3} targets V3. Note that each value needs a different endpoint URL.
     * Allowed values: general, generalv2, generalv3 (see {@link General}).
     */
    private String domain;
    /**
     * Optional. Sampling threshold that controls randomness: the higher it is, the more
     * likely the same question gets different answers. Range (0, 1], default 0.5.
     */
    private Double temperature;
    /**
     * Optional. Maximum number of tokens in the answer.
     * V1.5: [1, 4096]; V2.0: [1, 8192], default 2048; V3.0: [1, 8192], default 2048.
     */
    @JsonProperty("max_tokens")
    private Integer maxTokens;
    /**
     * Optional. Pick one of the top k candidates at random (not uniformly).
     * Range [1, 6], default 4.
     */
    @JsonProperty("top_k")
    private Integer topK;
    /**
     * Optional. Associates the request with a user session; must be unique per user.
     */
    @JsonProperty("chat_id")
    private String chatId;

    /**
     * Known values for {@link #domain}; {@link #getMsg()} is the wire value.
     */
    @Getter
    @AllArgsConstructor
    public enum General {
        general("general"),
        generalV2("generalv2"),
        generalV3("generalv3");

        /** Wire value for {@code domain}; final because enum constants are shared singletons. */
        private final String msg;
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/ChatHeader.java
================================================
package com.ai.spark.endPoint.chat;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.*;
/**
 * Header section of chat requests and responses. The client sends {@code appId} and
 * {@code uid}; {@code code}, {@code message}, {@code sid} and {@code status} come back in
 * responses.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatHeader {
    /** Required. Application id, taken from the app created in the open-platform console. */
    @JsonProperty("app_id")
    private String appId;
    /** Optional. Per-user id used to tell users apart. */
    private String uid;

    // ---- Fields populated in responses ----

    /** Error code: 0 means success, any other value is an error; see {@link Code}. */
    private Integer code;
    /** Description of whether the session succeeded. */
    private String message;
    /** Unique session id that iFlytek staff use to look up server logs; keep it when a call fails. */
    private String sid;
    /** Session status: 0 = first result, 1 = intermediate result, 2 = last result. */
    private Integer status;

    /** Known values of {@link #code}. */
    @Getter
    @AllArgsConstructor
    public enum Code {
        /** The request succeeded. */
        SUCCESS(0);

        /** Numeric value as returned in {@code header.code}. */
        private final int value;
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/ChatParameter.java
================================================
package com.ai.spark.endPoint.chat;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * {@code parameter} section of a chat request.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatParameter {
    /** Chat engine settings; see {@link Chat}. */
    private Chat chat;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/ChatPayload.java
================================================
package com.ai.spark.endPoint.chat;
import com.ai.spark.common.Usage;
import com.ai.spark.endPoint.chat.function.Function;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * {@code payload} section shared by chat requests and chat responses.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatPayload {
    /** Request: the conversation messages. */
    private Message message;
    /** Request: function definitions the model may call (serialized as {@code functions}). */
    @JsonProperty("functions")
    private Function function;

    // ---- Fields populated in responses ----

    /** Response: generated choices (serialized as {@code choices}). */
    @JsonProperty("choices")
    private Choice choice;
    /** Response: usage statistics; see {@link Usage}. */
    private Usage usage;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/ChatText.java
================================================
package com.ai.spark.endPoint.chat;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.*;
/**
 * One chat message: either the user's question or the assistant's reply.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ChatText {
    /**
     * Speaker role: {@code user} for the user's question, {@code assistant} for the AI's reply.
     * See {@link Role}.
     */
    private String role;
    /**
     * Message text. The combined tokens of all contents must stay within 8192.
     */
    private String content;
    /**
     * Result index, range [0, 10].
     */
    private Integer index;
    /**
     * Data type of the content (serialized as {@code content_type}); see {@link ContentType}.
     */
    @JsonProperty("content_type")
    private String contentType;

    /**
     * Creates a message with the given role and content; the other fields are left null.
     *
     * @param role    speaker role; must not be null
     * @param content message text
     * @return the new message
     */
    public static ChatText baseBuild(Role role, String content) {
        return ChatText.builder().role(role.getRoleName()).content(content).build();
    }

    /**
     * Allowed values for {@link #role}.
     */
    @Getter
    @AllArgsConstructor
    public enum Role {
        USER("user"),
        ASSISTANT("assistant"),
        ;
        /**
         * Wire value for {@code role}. Named in lowerCamelCase; Lombok still generates
         * {@code getRoleName()}, so callers are unaffected.
         */
        private final String roleName;
    }

    /**
     * Allowed values for {@link #contentType}.
     */
    @Getter
    @AllArgsConstructor
    public enum ContentType {
        TEXT("text"),
        IMAGE("image"),
        ;
        /** Wire value for {@code content_type}. */
        private final String type;
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/Choice.java
================================================
package com.ai.spark.endPoint.chat;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.*;
import java.util.List;
/**
 * {@code payload.choices} of a chat response.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class Choice {
    /**
     * Text response status: 0 = first text result, 1 = intermediate result, 2 = last result.
     * See {@link Status}.
     */
    private Integer status;
    /**
     * Sequence number of the returned data, range [0, 9999999].
     */
    private Integer seq;
    /**
     * Returned messages (serialized as {@code text}). The element type is declared so that
     * Jackson binds each element to {@link ChatText}. With a raw {@code List}, the elements
     * would be deserialized as generic {@code Map}s.
     */
    @JsonProperty("text")
    private List<ChatText> texts;

    /**
     * Known values of {@link #status}.
     */
    @Getter
    public enum Status {
        START(0),
        ING(1),
        END(2),
        ;
        private final int value;

        Status(int value) {
            this.value = value;
        }
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/Message.java
================================================
package com.ai.spark.endPoint.chat;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
/**
 * {@code payload.message} of a chat request: the conversation messages.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class Message {
    /**
     * Conversation messages (serialized as {@code text}). Typed as {@link ChatText} instead of
     * a raw {@code List}, so the compiler checks the elements and Jackson can bind them on read.
     */
    @JsonProperty("text")
    private List<ChatText> chatTexts;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/document/ChatExtends.java
================================================
package com.ai.spark.endPoint.chat.document;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Optional tuning settings for document Q&A.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatExtends {
    /**
     * Custom wiki Q&A prompt template, for cases where the service's default prompt gives poor
     * answers. The template contains placeholder markers that the service replaces with the
     * question and with the matched document text.
     */
    private String wikiPromptTpl;
    /**
     * Wiki relevance-score threshold; results scoring below it are discarded. Range (0, 1].
     * Reference values: 0.80 very loose, 0.82 loose, 0.83 standard, 0.84 strict, 0.86 very strict.
     * A higher value drops more passages but keeps more accurate ones; set it too high and
     * nothing may match.
     */
    private Float wikiFilterScore;
    /**
     * Whether the general LLM should answer when the question matches no document content.
     */
    private Boolean sparkWhenWithoutEmbedding;
    /**
     * Sampling temperature for the LLM answer, range (0, 1]; higher values are more random.
     */
    private Float temperature;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/function/Function.java
================================================
package com.ai.spark.endPoint.chat.function;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
/**
 * {@code payload.functions} of a chat request: the function definitions the model may call.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class Function {
    /**
     * Function definitions, serialized as {@code text} so that the wire shape is
     * {@code "functions": {"text": [...]}}. This mirrors {@code payload.message.text} and the
     * Spark function-call request format. Mapping this field to {@code functions} would nest
     * it as {@code "functions": {"functions": [...]}}.
     */
    @JsonProperty("text")
    private List<FunctionText> functionTexts;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/function/FunctionParameter.java
================================================
package com.ai.spark.endPoint.chat.function;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
/**
 * Parameter schema of a function definition ({@code parameters} of a {@link FunctionText}).
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class FunctionParameter {
    /**
     * Parameter type.
     */
    private String type;
    /**
     * User-defined description of the parameters to return when this function is matched.
     * Kept as {@code Object} because its structure is entirely caller-defined.
     */
    private Object properties;
    /**
     * User-defined names of the fields that must be returned when this function is matched.
     */
    private List<String> required;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/function/FunctionText.java
================================================
package com.ai.spark.endPoint.chat.function;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * A single function definition offered to the model.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@NoArgsConstructor
@AllArgsConstructor
public class FunctionText {
    /** Function name; returned when the user's input matches this function. */
    private String name;
    /** What the function does; the more detail, the better the model understands it. */
    private String description;
    /** Parameter schema (serialized as {@code parameters}). */
    @JsonProperty("parameters")
    private FunctionParameter functionParameter;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/req/ChatRequest.java
================================================
package com.ai.spark.endPoint.chat.req;
import com.ai.spark.endPoint.chat.*;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.Arrays;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ChatRequest {
@JsonProperty("header")
private ChatHeader chatHeader;
@JsonProperty("parameter")
private ChatParameter chatParameter;
@JsonProperty("payload")
private ChatPayload chatPayload;
public static ChatRequest baseBuild(String question, String appId) {
ChatHeader chatHeader = ChatHeader.builder().appId(appId).build();
Chat chat = Chat.builder().domain(Chat.General.generalV3.getMsg()).build();
ChatParameter chatParameter = ChatParameter.builder().chat(chat).build();
ChatText chatText = ChatText.builder().role(ChatText.Role.USER.getRoleName()).content(question).build();
Message message = Message.builder().chatTexts(new ArrayList<>(Arrays.asList(chatText))).build();
ChatPayload chatPayload = ChatPayload.builder().message(message).build();
return ChatRequest.builder().chatHeader(chatHeader).chatParameter(chatParameter).chatPayload(chatPayload).build();
}
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/req/DocumentChatRequest.java
================================================
package com.ai.spark.endPoint.chat.req;
import com.ai.spark.endPoint.chat.ChatText;
import com.ai.spark.endPoint.chat.document.ChatExtends;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Request body for document Q&A.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class DocumentChatRequest {
    /**
     * Optional tuning settings; see {@link ChatExtends}.
     */
    private ChatExtends chatExtends;
    /**
     * Required. Ids of the files to search when answering the question.
     */
    private List<String> fileIds;
    /**
     * Q&A history in chronological order; the last entry is the newest question
     * (serialized as {@code messages}).
     */
    @JsonProperty("messages")
    private List<ChatText> chatTexts;

    /**
     * Creates a request that asks {@code question} against the given files.
     *
     * @param question the user's question, sent as a single {@code user} message
     * @param fileIds  ids of the files to search; stored as given, not copied
     * @return the new request; {@code chatExtends} is left null
     */
    public static DocumentChatRequest baseBuild(String question, List<String> fileIds) {
        return DocumentChatRequest.builder()
                .fileIds(fileIds)
                .chatTexts(new ArrayList<>(Arrays.asList(ChatText.baseBuild(ChatText.Role.USER, question))))
                .build();
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/resp/ChatResponse.java
================================================
package com.ai.spark.endPoint.chat.resp;
import com.ai.spark.endPoint.chat.ChatHeader;
import com.ai.spark.endPoint.chat.ChatPayload;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * One chat response frame: header (status and error info) plus payload (choices and usage).
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatResponse {
    /** Serialized as {@code header}. */
    @JsonProperty("header")
    private ChatHeader chatHeader;
    /** Serialized as {@code payload}. */
    @JsonProperty("payload")
    private ChatPayload chatPayload;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/chat/resp/DocumentChatResponse.java
================================================
package com.ai.spark.endPoint.chat.resp;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.*;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class DocumentChatResponse {
/**
* 错误码 ,0 标识成功
*/
private Integer code;
/**
* 错误描述
*/
private String content;
/**
* 文档引用,status=99 的时候返回;结构是个 Map,key=文件 id,value=引用的文段列表(对应 fileTrunks 的 index)
*/
private String fileRefer;
/**
* 会话唯一标识
*/
private String sid;
/**
* 会话状态,取值为[0,1,2,99];0 代表首次结果;1 代表中间结果;2 代表最后一个结果;99 代表引用的文档及文段
*/
private Integer status;
@Getter
@AllArgsConstructor
public enum Code {
SUCCESS(0),
;
private final int value;
}
@Getter
public enum Status {
START(0),
ING(1),
END(2),
DOCUMENT(99),
;
private final int value;
Status(int value) {
this.value = value;
}
}
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/document/Data.java
================================================
package com.ai.spark.endPoint.document;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.NoArgsConstructor;
/**
 * {@code data} object of a file-upload response. Lombok's annotation is fully qualified
 * because this class is itself named {@code Data}.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@lombok.Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class Data {
    /** Id of the uploaded file. */
    private String fileId;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/document/req/FileUploadRequest.java
================================================
package com.ai.spark.endPoint.document.req;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.*;
import lombok.experimental.FieldNameConstants;
import java.io.File;
/**
 * Document upload request. Either {@link #file} or {@link #url} must be provided.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class FileUploadRequest {
    /**
     * File to upload.
     */
    private File file;
    /**
     * File URL; one of {@code file} or {@code url} is required.
     */
    private String url;
    /**
     * File name including extension. Required when uploading by {@code url}; optional when
     * sending {@code file}.
     */
    private String fileName;
    /**
     * File type; currently always {@code "wiki"}, which is the default.
     */
    @Builder.Default
    private String fileType = FileType.wiki.getType();
    /**
     * Callback URL the service calls when the file's status changes. The callback carries the
     * same authentication headers as API calls; the receiver decides whether to verify them.
     */
    private String callbackUrl;

    /**
     * Allowed values for {@link #fileType}.
     */
    @Getter
    @AllArgsConstructor
    public enum FileType {
        wiki("wiki");

        /** Wire value for {@code fileType}; final because enum constants are shared singletons. */
        private final String type;
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/document/resp/DocumentSummaryResponse.java
================================================
package com.ai.spark.endPoint.document.resp;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Response of the document summary endpoint.
 * NOTE(review): field semantics are not documented in this code. They presumably follow the
 * other document responses (code/sid/desc/data) — confirm against the API docs.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class DocumentSummaryResponse {
    /** Success flag. */
    private Boolean flag;
    /** Result code. */
    private Integer code;
    /** Request id. */
    private String sid;
    /** Result description. */
    private String desc;
    /** Result data; left untyped as {@code Object}. */
    private Object data;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/document/resp/FileUploadResponse.java
================================================
package com.ai.spark.endPoint.document.resp;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Response of the document upload endpoint.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class FileUploadResponse {
    /**
     * Result code. The uploaded file id is in {@link #data}, not here.
     * NOTE(review): presumably 0 on success, as with the other endpoints — confirm.
     */
    private Integer code;
    /**
     * Unique request id, used to trace problems.
     */
    private String sid;
    /**
     * Result description.
     */
    private String desc;
    /**
     * Result data, carrying the uploaded file's id.
     */
    private com.ai.spark.endPoint.document.Data data;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/embedding/Emb.java
================================================
package com.ai.spark.endPoint.embedding;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * Embedding settings ({@code parameter.emb} of an embedding request).
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
public class Emb {
    /** Feature settings; see {@link Feature}. */
    private Feature feature;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/embedding/EmbeddingHeader.java
================================================
package com.ai.spark.endPoint.embedding;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * Header section of embedding requests and responses.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
public class EmbeddingHeader {
    /** Required. Application id issued by the platform. */
    @JsonProperty("app_id")
    private String appId;
    /** Optional. Per-user id used to tell users apart. */
    private String uid;
    /** Send status; 3 (the default) means everything is sent in one shot. */
    @Builder.Default
    private Integer status = 3;

    // ---- Fields populated in responses ----

    /** Error code: 0 means success, any other value is an error. */
    private Integer code;
    /** Description of whether the session succeeded. */
    private String message;
    /** Unique session id that iFlytek staff use to look up server logs; keep it when a call fails. */
    private String sid;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/embedding/EmbeddingMessage.java
================================================
package com.ai.spark.endPoint.embedding;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * Message container of an embedding request ({@code payload.messages}).
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
public class EmbeddingMessage {
    /** Message text. NOTE(review): the exact encoding expected by the API is not shown here — confirm. */
    private String text;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/embedding/EmbeddingParameter.java
================================================
package com.ai.spark.endPoint.embedding;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * {@code parameter} section of an embedding request.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
public class EmbeddingParameter {
    /** Defaults to an {@link Emb} wrapping a default {@link Feature} (encoding {@code "utf8"}). */
    @Builder.Default
    private Emb emb = Emb.builder().feature(Feature.builder().build()).build();
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/embedding/EmbeddingPayload.java
================================================
package com.ai.spark.endPoint.embedding;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * The {@code payload} section shared by Spark embedding requests ({@code EmbeddingRequest})
 * and responses ({@code EmbeddingResponse}).
 * <p>
 * Null fields are omitted when serializing, and unknown JSON properties are ignored when reading.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class EmbeddingPayload {
// Request side: the text to embed, serialized under the JSON key "messages".
@JsonProperty("messages")
private EmbeddingMessage embeddingMessage;
// NOTE(review): not set by the request builder; presumably carries the returned embedding on
// responses — confirm against the API docs.
private Feature feature;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/embedding/Feature.java
================================================
package com.ai.spark.endPoint.embedding;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * Embedding feature descriptor.
 * <p>
 * Referenced from {@link EmbeddingPayload} and, through {@code Emb}, from {@link EmbeddingParameter},
 * whose default instance uses a builder-default {@code Feature}.
 * Null fields are omitted when serializing, and unknown JSON properties are ignored when reading.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class Feature {
// Text encoding; defaults to "utf8" when built via the builder.
@Builder.Default
private String encoding = "utf8";
// NOTE(review): seq/status/text are not set by the request builders visible here; presumably
// they are populated from responses — confirm against the API docs.
private String seq;
private String status;
private String text;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/embedding/req/EmbeddingRequest.java
================================================
package com.ai.spark.endPoint.embedding.req;
import com.ai.common.utils.JsonUtils;
import com.ai.spark.endPoint.chat.ChatText;
import com.ai.spark.endPoint.embedding.EmbeddingHeader;
import com.ai.spark.endPoint.embedding.EmbeddingMessage;
import com.ai.spark.endPoint.embedding.EmbeddingParameter;
import com.ai.spark.endPoint.embedding.EmbeddingPayload;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
/**
 * Request body for the Spark text-embedding endpoint ({@code header} / {@code parameter} / {@code payload}).
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class EmbeddingRequest {
    /** Request header carrying the application id. */
    @JsonProperty("header")
    private EmbeddingHeader embeddingHeader;
    /** Request parameters; the builder default selects UTF-8 text encoding. */
    @JsonProperty("parameter")
    private EmbeddingParameter embeddingParameter;
    /** Request payload holding the Base64-encoded text. */
    @JsonProperty("payload")
    private EmbeddingPayload embeddingPayload;

    /**
     * Builds an embedding request for a single chat message.
     * <p>
     * The Spark embedding API expects {@code payload.messages.text} to be the Base64 encoding of a
     * UTF-8 JSON document shaped like {@code {"messages":[{"role":"user","content":"..."}]}}. The
     * serialized {@code text} is therefore wrapped in a one-element {@code messages} array before
     * encoding.
     *
     * @param text  the message to embed; must not be {@code null}
     * @param appId the application id placed in the request header
     * @return a request populated with default parameters
     * @throws NullPointerException if {@code text} is {@code null}
     */
    public static EmbeddingRequest baseBuild(ChatText text, String appId) {
        if (text == null) {
            throw new NullPointerException("text");
        }
        // Assumes JsonUtils.toJson serializes a ChatText as a JSON object, so splicing it into an
        // array literal yields valid JSON without requiring a different JsonUtils overload.
        String wrapped = "{\"messages\":[" + JsonUtils.toJson(text) + "]}";
        String encoded = Base64.getEncoder().encodeToString(wrapped.getBytes(StandardCharsets.UTF_8));
        return EmbeddingRequest.builder()
                .embeddingHeader(EmbeddingHeader.builder().appId(appId).build())
                .embeddingParameter(EmbeddingParameter.builder().build())
                .embeddingPayload(EmbeddingPayload.builder()
                        .embeddingMessage(EmbeddingMessage.builder().text(encoded).build())
                        .build())
                .build();
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/embedding/resp/EmbeddingResponse.java
================================================
package com.ai.spark.endPoint.embedding.resp;
import com.ai.spark.endPoint.embedding.EmbeddingHeader;
import com.ai.spark.endPoint.embedding.EmbeddingPayload;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * Response body returned by the Spark text-embedding endpoint.
 * <p>
 * Unknown JSON properties are ignored when reading; null fields are omitted when serializing.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class EmbeddingResponse {
// Response header, mapped from the JSON key "header".
@JsonProperty("header")
private EmbeddingHeader embeddingHeader;
// Response payload, mapped from the JSON key "payload".
@JsonProperty("payload")
private EmbeddingPayload embeddingPayload;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/images/ImageChat.java
================================================
package com.ai.spark.endPoint.images;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * Generation settings for the Spark image-generation endpoint, serialized as
 * {@code parameter.chat} by {@link ImageParameter}.
 * <p>
 * When built via the builder the defaults are domain {@code "general"} and a 512x512 image.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ImageChat {
// Model domain; builder default "general".
@Builder.Default
private String domain = "general";
// Output image width; builder default 512.
@Builder.Default
private Integer width = 512;
// Output image height; builder default 512.
@Builder.Default
private Integer height = 512;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/images/ImageHeader.java
================================================
package com.ai.spark.endPoint.images;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * Header section shared by Spark image requests and responses.
 * <p>
 * Request builders in this module only set {@code appId}; the remaining fields are read from responses.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ImageHeader {
// Application id, serialized as "app_id".
@JsonProperty("app_id")
private String appId;
private String uid;
// The properties below are used in responses.
private Integer code;
private String message;
private String sid;
private Integer status;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/images/ImageParameter.java
================================================
package com.ai.spark.endPoint.images;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * The {@code parameter} section of a Spark image-generation request.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ImageParameter {
// Generation settings, serialized under "chat"; builder default is a default-built ImageChat.
@Builder.Default
@JsonProperty("chat")
private ImageChat imageChat = ImageChat.builder().build();
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/images/ImagePayload.java
================================================
package com.ai.spark.endPoint.images;
import com.ai.spark.endPoint.chat.Choice;
import com.ai.spark.endPoint.chat.Message;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * The {@code payload} section shared by Spark image requests and image-generation responses.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ImagePayload {
// Request side: the messages sent to the model (set by the image request builders).
private Message message;
// Response side: model output, mapped from the JSON key "choices".
@JsonProperty("choices")
private Choice choice;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/images/ImageUnderstandingChat.java
================================================
package com.ai.spark.endPoint.images;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * Model settings for the Spark image-understanding endpoint, serialized as {@code parameter.chat}
 * by {@link ImageUnderstandingParameter}. Null settings are omitted from the JSON.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ImageUnderstandingChat {
// Model domain; builder default "general".
// NOTE(review): confirm that "general" is the domain the image-understanding endpoint expects.
@Builder.Default
private String domain = "general";
private String auditing;
private Double temperature;
// Serialized as "top_k".
@JsonProperty("top_k")
private Integer topK;
// Serialized as "max_tokens".
@JsonProperty("max_tokens")
private Integer maxTokens;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/images/ImageUnderstandingParameter.java
================================================
package com.ai.spark.endPoint.images;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * The {@code parameter} section of a Spark image-understanding request.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ImageUnderstandingParameter {
// Model settings, serialized under "chat"; builder default is a default-built ImageUnderstandingChat.
@Builder.Default
@JsonProperty("chat")
private ImageUnderstandingChat imageUnderstandingChat = ImageUnderstandingChat.builder().build();
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/images/ImageUnderstandingPayload.java
================================================
package com.ai.spark.endPoint.images;
import com.ai.spark.common.Usage;
import com.ai.spark.endPoint.chat.Choice;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * The {@code payload} section of a Spark image-understanding response.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ImageUnderstandingPayload {
// Model output, mapped from the JSON key "choices".
@JsonProperty("choices")
private Choice choice;
// NOTE(review): presumably token-usage statistics for the call — confirm against com.ai.spark.common.Usage.
private Usage usage;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/images/req/ImageCreateRequest.java
================================================
package com.ai.spark.endPoint.images.req;
import com.ai.spark.endPoint.chat.ChatText;
import com.ai.spark.endPoint.chat.Message;
import com.ai.spark.endPoint.images.ImageHeader;
import com.ai.spark.endPoint.images.ImageParameter;
import com.ai.spark.endPoint.images.ImagePayload;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
import java.util.Arrays;
/**
 * Request body for the Spark image-generation endpoint ({@code header} / {@code parameter} /
 * {@code payload}).
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ImageCreateRequest {
    /** Request header carrying the application id. */
    @JsonProperty("header")
    private ImageHeader imageHeader;
    /** Generation parameters (domain and image size). */
    @JsonProperty("parameter")
    private ImageParameter imageParameter;
    /** Payload holding the prompt as a user message. */
    @JsonProperty("payload")
    private ImagePayload imagePayload;

    /**
     * Creates an image-generation request for {@code content} using default parameters.
     *
     * @param content the image description, sent as a single user message
     * @param appId   application id placed in the header
     * @return the assembled request
     */
    public static ImageCreateRequest baseBuild(String content, String appId) {
        ImageHeader header = ImageHeader.builder().appId(appId).build();
        ImageParameter parameter = ImageParameter.builder().build();
        ChatText prompt = ChatText.baseBuild(ChatText.Role.USER, content);
        Message message = Message.builder().chatTexts(Arrays.asList(prompt)).build();
        ImagePayload payload = ImagePayload.builder().message(message).build();
        return ImageCreateRequest.builder()
                .imageHeader(header)
                .imageParameter(parameter)
                .imagePayload(payload)
                .build();
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/images/req/ImageUnderstandingRequest.java
================================================
package com.ai.spark.endPoint.images.req;
import com.ai.spark.endPoint.chat.ChatText;
import com.ai.spark.endPoint.chat.Message;
import com.ai.spark.endPoint.images.ImageHeader;
import com.ai.spark.endPoint.images.ImagePayload;
import com.ai.spark.endPoint.images.ImageUnderstandingParameter;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Objects;
/**
 * Request body for the Spark image-understanding endpoint ({@code header} / {@code parameter} /
 * {@code payload}).
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ImageUnderstandingRequest {
    /** Request header carrying the application id. */
    @JsonProperty("header")
    private ImageHeader imageHeader;
    /** Model settings for image understanding. */
    @JsonProperty("parameter")
    private ImageUnderstandingParameter imageUnderstandingParameter;
    /** Payload holding the image followed by the question. */
    @JsonProperty("payload")
    private ImagePayload imagePayload;

    /**
     * Builds a request asking {@code content} about the given image.
     * <p>
     * The image bytes are Base64-encoded and sent as the first user message (content type image).
     * The question text follows as a second user message (content type text). The message list is
     * mutable, so callers may append further turns.
     *
     * @param content the question about the image
     * @param appId   application id placed in the header
     * @param image   the image file to send; must exist and be readable
     * @return the assembled request
     * @throws NullPointerException if {@code image} is {@code null}
     * @throws UncheckedIOException if the image file cannot be read
     */
    public static ImageUnderstandingRequest baseBuild(String content, String appId, File image) {
        String imageStr = Base64.getEncoder().encodeToString(readImage(image));
        ImageHeader imageHeader = ImageHeader.builder().appId(appId).build();
        ImageUnderstandingParameter imageUnderstandingParameter = ImageUnderstandingParameter.builder().build();
        ChatText chatText = ChatText.baseBuild(ChatText.Role.USER, content);
        chatText.setContentType(ChatText.ContentType.TEXT.getType());
        ChatText imgText = ChatText.baseBuild(ChatText.Role.USER, imageStr);
        imgText.setContentType(ChatText.ContentType.IMAGE.getType());
        // Image first, then the question.
        Message message = Message.builder().chatTexts(new ArrayList<>(Arrays.asList(imgText, chatText))).build();
        ImagePayload imagePayload = ImagePayload.builder().message(message).build();
        return ImageUnderstandingRequest
                .builder()
                .imageHeader(imageHeader)
                .imagePayload(imagePayload)
                .imageUnderstandingParameter(imageUnderstandingParameter)
                .build();
    }

    /**
     * Reads every byte of {@code image}.
     * <p>
     * A single {@code read} into a buffer sized by {@code available()} may return fewer bytes than
     * the file holds. Swallowing the IOException would also leave a null array that later fails
     * with an unrelated NPE. Here the file is read in full and failures are surfaced with their cause.
     */
    private static byte[] readImage(File image) {
        Objects.requireNonNull(image, "image");
        try {
            return Files.readAllBytes(image.toPath());
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read image file: " + image, e);
        }
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/images/resp/ImageCreateResponse.java
================================================
package com.ai.spark.endPoint.images.resp;
import com.ai.spark.endPoint.images.ImageHeader;
import com.ai.spark.endPoint.images.ImagePayload;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * Response body returned by the Spark image-generation endpoint.
 * <p>
 * The generated image is read from {@code payload.choices}; unknown JSON properties are ignored.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ImageCreateResponse {
// Response header (status/code/message), mapped from the JSON key "header".
@JsonProperty("header")
private ImageHeader imageHeader;
// Response payload, mapped from the JSON key "payload".
@JsonProperty("payload")
private ImagePayload imagePayload;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/endPoint/images/resp/ImageUnderstandingResponse.java
================================================
package com.ai.spark.endPoint.images.resp;
import com.ai.spark.endPoint.images.ImageHeader;
import com.ai.spark.endPoint.images.ImageUnderstandingPayload;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.FieldNameConstants;
/**
 * Response body returned by the Spark image-understanding endpoint.
 * <p>
 * Unknown JSON properties are ignored when reading; null fields are omitted when serializing.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@FieldNameConstants
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ImageUnderstandingResponse {
// Response header, mapped from the JSON key "header".
@JsonProperty("header")
private ImageHeader imageHeader;
// Response payload (choices and usage), mapped from the JSON key "payload".
@JsonProperty("payload")
private ImageUnderstandingPayload imageUnderstandingPayload;
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/interceptor/BaseUrlInterceptor.java
================================================
package com.ai.spark.interceptor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import java.io.IOException;
/**
 * OkHttp interceptor reserved for rewriting request URLs.
 * <p>
 * TODO: intercept and rewrite request paths. Currently every request is forwarded unchanged.
 */
@Slf4j
public class BaseUrlInterceptor implements Interceptor {
    @Override
    public Response intercept(Chain chain) throws IOException {
        // Pass-through: forward the incoming request as-is.
        return chain.proceed(chain.request());
    }
}
================================================
FILE: ai-spark/src/main/java/com/ai/spark/interceptor/ResponseInterceptor.java
================================================
package com.ai.spark.interceptor;
import com.ai.core.exception.BaseException;
import com.ai.core.exception.Constants;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import java.io.IOException;
/**
* 返回信息拦截器
*/
@Slf4j
public class ResponseInterceptor implements Interceptor {
@Override
public Response intercept(Chain chain) throws IOException {
// 1. 获取 req 和 resp
Request original = chain.request();
Response response = chain.proceed(original);
// 2. 排除webSocket连接,判断返回状态
if (!"websocket".equalsIgnoreCase(response.header("Upgrade"))
&& !"Upgrade".equalsIgnoreCase(response.header("Connection"))
&& !response.isSuccessful()
&& response.body() != null) {
// 2.1 获取返回的错误信息
log.error(response.body().string());
throw new BaseException(Constants.ErrorMsg.RETRY_ERROR);
}
return response;
}
}
================================================
FILE: ai-spark/src/test/java/com/ai/spark/AudioApiTest.java
================================================
package com.ai.spark;
import com.ai.core.strategy.impl.FirstKeyStrategy;
import com.ai.spark.achieve.ApiData;
import com.ai.spark.achieve.Configuration;
import com.ai.spark.achieve.defaults.DefaultSparkSessionFactory;
import com.ai.spark.achieve.standard.session.AggregationSession;
import org.junit.Before;
import java.util.Arrays;
/**
 * Placeholder for Spark audio integration tests.
 * <p>
 * No audio tests exist yet; the class only prepares an aggregation session. The credentials are
 * masked and must be filled in before running.
 */
public class AudioApiTest {
    private AggregationSession aggregationSession;

    /**
     * Builds an aggregation session backed by a single API key before each test.
     * <p>
     * Other options: {@code RandomKeyStrategy} instead of {@code FirstKeyStrategy}, and an HTTP proxy
     * via {@code configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)))}.
     */
    @Before
    public void before() {
        Configuration configuration = new Configuration();
        configuration.setApiHost("https://spark-api.xf-yun.com");
        ApiData credentials = ApiData.builder()
                .apiKey("***********************")
                .apiSecret("***********************")
                .appId("***********************")
                .build();
        configuration.setKeyList(Arrays.asList(credentials));
        configuration.setKeyStrategy(new FirstKeyStrategy());
        this.aggregationSession = new DefaultSparkSessionFactory(configuration).openAggregationSession();
    }
}
================================================
FILE: ai-spark/src/test/java/com/ai/spark/ChatApiTest.java
================================================
package com.ai.spark;
import com.ai.core.strategy.impl.FirstKeyStrategy;
import com.ai.spark.achieve.ApiData;
import com.ai.spark.achieve.Configuration;
import com.ai.spark.achieve.defaults.DefaultSparkSessionFactory;
import com.ai.spark.achieve.defaults.listener.ChatListener;
import com.ai.spark.achieve.defaults.listener.DocumentChatListener;
import com.ai.spark.achieve.standard.session.AggregationSession;
import com.ai.spark.endPoint.chat.ChatText;
import com.ai.spark.endPoint.chat.req.ChatRequest;
import com.ai.spark.endPoint.chat.req.DocumentChatRequest;
import com.ai.spark.endPoint.chat.resp.ChatResponse;
import com.ai.spark.endPoint.chat.resp.DocumentChatResponse;
import lombok.SneakyThrows;
import org.junit.Before;
import org.junit.Test;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
/**
 * Manual integration tests for the Spark chat endpoints.
 * <p>
 * The credentials are masked and must be filled in before running. Chat calls report their results
 * through listener callbacks, so each test blocks on a latch. The latch is released by the end or
 * error callback, and the wait is bounded by {@link #SESSION_TIMEOUT_MINUTES}.
 */
public class ChatApiTest {
    /** Upper bound for waiting on a chat session, so a lost callback cannot hang the build. */
    private static final long SESSION_TIMEOUT_MINUTES = 2;
    private AggregationSession aggregationSession;

    @Before
    public void before() {
        // 1. Create the configuration.
        Configuration configuration = new Configuration();
        configuration.setApiHost("https://spark-api.xf-yun.com");
        // 2. Credentials used for authentication; several keys may be registered.
        ApiData apiData = ApiData.builder()
                .apiKey("**********************")
                .apiSecret("**********************")
                .appId("**********************")
                .build();
        configuration.setKeyList(Arrays.asList(apiData));
        // 3. Key selection strategy: FirstKeyStrategy (always the first key) or RandomKeyStrategy.
        configuration.setKeyStrategy(new FirstKeyStrategy());
        // configuration.setKeyStrategy(new RandomKeyStrategy());
        // 4. Optional proxy.
        // configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)));
        // 5. Create the session factory and open the aggregation session.
        DefaultSparkSessionFactory factory = new DefaultSparkSessionFactory(configuration);
        this.aggregationSession = factory.openAggregationSession();
    }

    /**
     * Tests a single-turn chat.
     */
    @Test
    public void test_chat() {
        ChatRequest request = ChatRequest.baseBuild("讲一个笑话", "c8f362b8");
        CountDownLatch countDownLatch = new CountDownLatch(1);
        aggregationSession.getChatSession().chat(new ChatListener(request) {
            // Error handling
            @SneakyThrows
            @Override
            public void onChatError(ChatResponse chatResponse) {
                System.out.println(chatResponse);
                countDownLatch.countDown();
            }

            // Streamed output
            @Override
            public void onChatOutput(ChatResponse chatResponse) {
                System.out.println(chatResponse);
                System.out.print(chatResponse.getChatPayload().getChoice().getTexts().get(0).getContent());
            }

            // Session finished
            @Override
            public void onChatEnd() {
                System.out.println("当前会话结束了");
                countDownLatch.countDown();
            }
        });
        awaitSessionEnd(countDownLatch);
    }

    /**
     * Tests a multi-turn chat by appending a previous answer and a follow-up question.
     */
    @Test
    public void test_chat_multiple() {
        ChatRequest request = ChatRequest.baseBuild("1+1=", "c8f362b8");
        // Result of the first round
        ChatText chatText1 = ChatText.baseBuild(ChatText.Role.ASSISTANT, "2");
        // Question of the second round
        ChatText chatText2 = ChatText.baseBuild(ChatText.Role.USER, "2+2=");
        request.getChatPayload().getMessage().getChatTexts().add(chatText1);
        request.getChatPayload().getMessage().getChatTexts().add(chatText2);
        CountDownLatch countDownLatch = new CountDownLatch(1);
        aggregationSession.getChatSession().chat(new ChatListener(request) {
            @Override
            public void onChatError(ChatResponse chatResponse) {
                System.out.println(chatResponse);
                countDownLatch.countDown();
            }

            @Override
            public void onChatOutput(ChatResponse chatResponse) {
                System.out.print(chatResponse.getChatPayload().getChoice().getTexts().get(0).getContent());
            }

            @Override
            public void onChatEnd() {
                System.out.println("当前会话结束了");
                countDownLatch.countDown();
            }
        });
        awaitSessionEnd(countDownLatch);
    }

    /**
     * Tests chatting against previously uploaded documents.
     */
    @Test
    public void test_document_chat() {
        DocumentChatRequest documentChatRequest = DocumentChatRequest.baseBuild("总结一下故事一说了什么?", Arrays.asList("c42a68fd31964d43b4f57eab11e9a833"));
        CountDownLatch countDownLatch = new CountDownLatch(1);
        aggregationSession.getChatSession().documentChat(new DocumentChatListener(documentChatRequest) {
            @Override
            public void onChatError(DocumentChatResponse documentChatResponse) {
                System.err.println(documentChatResponse);
                countDownLatch.countDown();
            }

            @Override
            public void onChatOutput(DocumentChatResponse documentChatResponse) {
                System.out.println(documentChatResponse);
            }

            @Override
            public void onChatEnd() {
                System.out.println("当前会话结束了");
                countDownLatch.countDown();
            }
        });
        awaitSessionEnd(countDownLatch);
    }

    /**
     * Waits until the listener releases {@code latch}. The test fails if the session does not
     * finish within the timeout. The interrupt flag is restored if the wait is interrupted.
     */
    private static void awaitSessionEnd(CountDownLatch latch) {
        try {
            if (!latch.await(SESSION_TIMEOUT_MINUTES, java.util.concurrent.TimeUnit.MINUTES)) {
                throw new AssertionError("Session did not finish within " + SESSION_TIMEOUT_MINUTES + " minutes");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new AssertionError("Interrupted while waiting for the session to finish", e);
        }
    }
}
================================================
FILE: ai-spark/src/test/java/com/ai/spark/DocumentApiTest.java
================================================
package com.ai.spark;
import com.ai.core.strategy.impl.FirstKeyStrategy;
import com.ai.spark.achieve.ApiData;
import com.ai.spark.achieve.Configuration;
import com.ai.spark.achieve.defaults.DefaultSparkSessionFactory;
import com.ai.spark.achieve.standard.session.AggregationSession;
import com.ai.spark.endPoint.document.req.FileUploadRequest;
import com.ai.spark.endPoint.document.resp.DocumentSummaryResponse;
import com.ai.spark.endPoint.document.resp.FileUploadResponse;
import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.util.Arrays;
/**
 * Manual integration tests for the Spark document endpoints (upload and summary).
 * <p>
 * The two summary endpoints differ only in request path, giving an asynchronous workflow.
 * {@code documentSummaryStart} does not return the summary directly; it tells the model to begin
 * summarizing. {@code documentSummaryQuery} then fetches the result. Once a summary exists, both
 * endpoints return the same data. Credentials are masked.
 */
public class DocumentApiTest {
    /** ID of a previously uploaded document (see the sample upload responses below). */
    private static final String DOCUMENT_ID = "004c3c6e79bc4d738a7e94a12697ea75";
    private AggregationSession aggregationSession;

    /**
     * Builds an aggregation session backed by a single API key before each test.
     * Use {@code RandomKeyStrategy} or {@code configuration.setProxy(...)} if needed.
     */
    @Before
    public void before() {
        ApiData apiData = ApiData.builder()
                .apiKey("***********************")
                .apiSecret("***********************")
                .appId("***********************")
                .build();
        Configuration configuration = new Configuration();
        configuration.setApiHost("https://spark-api.xf-yun.com");
        configuration.setKeyList(Arrays.asList(apiData));
        configuration.setKeyStrategy(new FirstKeyStrategy());
        aggregationSession = new DefaultSparkSessionFactory(configuration).openAggregationSession();
    }

    /**
     * Uploads a local PDF and prints the response.
     */
    @Test
    public void test_file_upload() {
        FileUploadRequest request = FileUploadRequest.builder()
                .file(new File("D:\\chatGPT-api\\AI-java\\doc\\test\\test_file_upload.pdf"))
                .build();
        FileUploadResponse fileUploadResponse = aggregationSession.getDocumentSession().fileUpload(request);
        System.out.println(fileUploadResponse);
        // Sample responses:
        // FileUploadResponse(code=0, sid=28db14303e054046aabd2e96e7e65c51, desc=null, data=Data(fileId=1a477e7e9cb44e23ad4abd98076e3f70))
        // FileUploadResponse(code=0, sid=8e4f267415d84827a6ec7a1580e1ce64, desc=null, data=Data(fileId=004c3c6e79bc4d738a7e94a12697ea75))
    }

    /**
     * Asks the model to start summarizing the document.
     */
    @Test
    public void test_document_summary_start() {
        System.out.println(aggregationSession.getDocumentSession().documentSummaryStart(DOCUMENT_ID));
    }

    /**
     * Queries the summary result of the document.
     */
    @Test
    public void test_document_summary_query() {
        DocumentSummaryResponse summary = aggregationSession.getDocumentSession().documentSummaryQuery(DOCUMENT_ID);
        System.out.println(summary);
    }
}
================================================
FILE: ai-spark/src/test/java/com/ai/spark/EmbeddingApiTest.java
================================================
package com.ai.spark;
import com.ai.core.strategy.impl.FirstKeyStrategy;
import com.ai.spark.achieve.ApiData;
import com.ai.spark.achieve.Configuration;
import com.ai.spark.achieve.defaults.DefaultSparkSessionFactory;
import com.ai.spark.achieve.standard.session.AggregationSession;
import com.ai.spark.endPoint.chat.ChatText;
import com.ai.spark.endPoint.embedding.req.EmbeddingRequest;
import com.ai.spark.endPoint.embedding.resp.EmbeddingResponse;
import org.junit.Before;
import org.junit.Test;
import java.util.Arrays;
/**
 * Manual integration test for the Spark text-embedding endpoint. Credentials are masked and must
 * be filled in before running.
 */
public class EmbeddingApiTest {
    private AggregationSession aggregationSession;

    /**
     * Builds an aggregation session backed by a single API key before each test.
     * Use {@code RandomKeyStrategy} or {@code configuration.setProxy(...)} if needed.
     */
    @Before
    public void before() {
        ApiData apiData = ApiData.builder()
                .apiKey("**********************")
                .apiSecret("**********************")
                .appId("**********************")
                .build();
        Configuration configuration = new Configuration();
        configuration.setApiHost("https://spark-api.xf-yun.com");
        configuration.setKeyList(Arrays.asList(apiData));
        configuration.setKeyStrategy(new FirstKeyStrategy());
        aggregationSession = new DefaultSparkSessionFactory(configuration).openAggregationSession();
    }

    /**
     * Embeds a short piece of text and prints the response.
     */
    @Test
    public void test_embedding() {
        ChatText userText = ChatText.baseBuild(ChatText.Role.USER, "这是一段文字");
        EmbeddingResponse response = aggregationSession.getEmbeddingSession()
                .embed(EmbeddingRequest.baseBuild(userText, "c8f362b8"));
        System.out.println(response);
    }
}
================================================
FILE: ai-spark/src/test/java/com/ai/spark/ImageApiTest.java
================================================
package com.ai.spark;
import com.ai.core.strategy.impl.FirstKeyStrategy;
import com.ai.spark.achieve.ApiData;
import com.ai.spark.achieve.Configuration;
import com.ai.spark.achieve.defaults.DefaultSparkSessionFactory;
import com.ai.spark.achieve.defaults.listener.ImageUnderstandingListener;
import com.ai.spark.achieve.standard.session.AggregationSession;
import com.ai.spark.endPoint.images.req.ImageCreateRequest;
import com.ai.spark.endPoint.images.req.ImageUnderstandingRequest;
import com.ai.spark.endPoint.images.resp.ImageCreateResponse;
import com.ai.spark.endPoint.images.resp.ImageUnderstandingResponse;
import org.junit.Before;
import org.junit.Test;
import javax.imageio.ImageIO;
import javax.xml.bind.DatatypeConverter;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
/**
 * Manual integration tests for the Spark image endpoints (generation and understanding).
 * <p>
 * The credentials are masked and must be filled in before running. Image understanding reports
 * results through listener callbacks, so that test blocks on a latch. The latch is released by the
 * end or error callback, and the wait is bounded by {@link #SESSION_TIMEOUT_MINUTES}.
 */
public class ImageApiTest {
    /** Upper bound for waiting on a streaming session, so a lost callback cannot hang the build. */
    private static final long SESSION_TIMEOUT_MINUTES = 2;
    private AggregationSession aggregationSession;

    @Before
    public void before() {
        // 1. Create the configuration.
        Configuration configuration = new Configuration();
        configuration.setApiHost("https://spark-api.xf-yun.com");
        // 2. Credentials used for authentication; several keys may be registered.
        ApiData apiData = ApiData.builder()
                .apiKey("**********************")
                .apiSecret("**********************")
                .appId("**********************")
                .build();
        configuration.setKeyList(Arrays.asList(apiData));
        // 3. Key selection strategy: FirstKeyStrategy (always the first key) or RandomKeyStrategy.
        configuration.setKeyStrategy(new FirstKeyStrategy());
        // configuration.setKeyStrategy(new RandomKeyStrategy());
        // 4. Optional proxy.
        // configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)));
        // 5. Create the session factory and open the aggregation session.
        DefaultSparkSessionFactory factory = new DefaultSparkSessionFactory(configuration);
        this.aggregationSession = factory.openAggregationSession();
    }

    /**
     * Generates an image and writes it to a local PNG file.
     */
    @Test
    public void test_image_create() throws IOException {
        ImageCreateRequest request = ImageCreateRequest.baseBuild("画一座大山", "c8f362b8");
        ImageCreateResponse imageCreateResponse = aggregationSession.getImageSession().imageCreate(request);
        // The image is returned as a Base64 string; anything up to the first ',' (if present) is stripped.
        String content = imageCreateResponse.getImagePayload().getChoice().getTexts().get(0).getContent();
        byte[] imageBytes = DatatypeConverter.parseBase64Binary(content.substring(content.indexOf(",") + 1));
        BufferedImage bufferedImage = ImageIO.read(new ByteArrayInputStream(imageBytes));
        if (bufferedImage == null) {
            // ImageIO.read returns null when no registered reader recognizes the bytes.
            throw new AssertionError("Response did not contain a decodable image");
        }
        File outputFile = new File("D:\\chatGPT-api\\AI-java\\doc\\test\\test_create_image.png");
        ImageIO.write(bufferedImage, "png", outputFile);
    }

    /**
     * Asks the model what a local image shows and prints the streamed answer.
     */
    @Test
    public void test_image_understanding() {
        String filePath = "D:\\chatGPT-api\\AI-java\\doc\\test\\test_create_image.png";
        File file = new File(filePath);
        ImageUnderstandingRequest request = ImageUnderstandingRequest.baseBuild("这张图片的内容是什么?", "c8f362b8", file);
        CountDownLatch countDownLatch = new CountDownLatch(1);
        aggregationSession.getImageSession().imageUnderstanding(request, new ImageUnderstandingListener(request) {
            @Override
            public void onChatError(ImageUnderstandingResponse imageUnderstandingResponse) {
                System.err.println(imageUnderstandingResponse);
                countDownLatch.countDown();
            }

            @Override
            public void onChatOutput(ImageUnderstandingResponse imageUnderstandingResponse) {
                System.out.println(imageUnderstandingResponse.getImageUnderstandingPayload().getChoice().getTexts().get(0).getContent());
            }

            @Override
            public void onChatEnd() {
                countDownLatch.countDown();
            }
        });
        awaitSessionEnd(countDownLatch);
    }

    /**
     * Waits until the listener releases {@code latch}. The test fails if the session does not
     * finish within the timeout. The interrupt flag is restored if the wait is interrupted.
     */
    private static void awaitSessionEnd(CountDownLatch latch) {
        try {
            if (!latch.await(SESSION_TIMEOUT_MINUTES, java.util.concurrent.TimeUnit.MINUTES)) {
                throw new AssertionError("Session did not finish within " + SESSION_TIMEOUT_MINUTES + " minutes");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new AssertionError("Interrupted while waiting for the session to finish", e);
        }
    }
}
================================================
FILE: doc/test/test_file_upload.txt
================================================
故事1:小老虎问路
一头骄傲的小老虎在大森林里迷了路。
他走啊走,看到了一头正在蒙头大睡的野猪。小老虎对着野猪的耳朵,大声喊道:“喂,蠢猪,别打呼噜了,告诉我回家的路怎么走吧!”野猪生气地眨了眨眼睛,一言不发,把屁股转向了小老虎,继续睡大觉。小老虎讨了个没趣,无奈地走了。
小老虎问路
路上,他看到一只正在忙碌的小松鼠,于是他用自己的大嗓门儿喊道:“喂,如果你告诉我回家的路怎么走,我就让妈妈给你最好的礼物!”小松鼠就像没听见一样,不搭理小老虎,照样干自己的活儿。小老虎勃然大怒,冲向一只戴眼镜的老灰兔:“嘿,花眼的老兔头,快给我指一条回家的路!”老灰兔慢慢地抬起头,和蔼地说:“森林里的路大家都熟悉,可你这样没礼貌,哪怕你问遍所有的动物,你还是找不到回家的路。”听了老灰兔的话,小老虎猛然醒悟过来——对人说话,要有礼貌才行!
这时,前面过来一只梅花鹿。小老虎走过去,礼貌地说:“梅花鹿你好,请你告诉我回家的路怎么走,好吗?”梅花鹿热情地告诉了小老虎,小老虎高兴地连声说:“谢谢你,梅花鹿,谢谢你!”
小老虎终于安全地回到了自己的小屋。
故事2:鸡和猫调工作
张三家里养了一只猫和一只公鸡。
一天早晨,“喔,喔,喔……”公鸡的长鸣把沉睡的猫叫醒了,猫揉了揉眼说:“死公鸡,没娘教的孩子,吵死啊!”公鸡大声地回骂道:“死懒猫,你才没娘教,太阳晒屁股了,还不起来干活!”猫又回敬了一句:“哼,我晚上辛辛苦苦捉老鼠,白天还不让我多睡会儿?”公鸡无语了。
过了一会儿,猫想出一个好办法,对公鸡说:“不如,我们调换一下工作,你去捉老鼠,我来打鸣报时,你敢不敢?”公鸡自信地说:“哪有不敢?我就怕你不敢!”猫说:“好,那我们就调工作吧!”
第二天早晨,猫早早地起来,站在房顶上,“喵,喵,喵……”地叫着,可是人们都听习惯了公鸡叫天亮,猫的叫声只把附近的人叫醒了,远处的人根本听不见。后来,猫的嗓子都喊破了,被送到宠物医院治疗。
晚上,公鸡来到老鼠洞前,对老鼠说:“鸡大爷来了,快给我出来!”鼠王以为是猫的诡计,就叫了一只小老鼠去侦察,小老鼠侦察了一番回来说:“站在洞门口的是一只公鸡。”于是,鼠们就往老鼠洞里钻。公鸡见一只老鼠都不出来,就把锋利的爪子伸进老鼠洞里,一只老鼠胆大包天窜到洞门口对准鸡爪狠狠地咬了一口,鸡痛得哇哇大叫,也被送进了宠物医院。
猫和公鸡在宠物医院相遇了,当他们见彼此都受了伤后,明白了一个道理:尺有所短,寸有所长。
故事3:固执的鱼
在一个池塘里住着一条小鱼。他有一个好朋友,是一只蝌蚪。这条小鱼和蝌蚪总是在一起游泳,一起找食物,一起玩。
一天早上,小鱼吃惊地看见在蝌蚪尾巴的两边长出了一对腿。小鱼问蝌蚪为什么他会有腿。
“我不是鱼,我是一只蝌蚪,一只年幼的青蛙。以后,等我长大了,我就不再呆在这儿了。”蝌蚪回答说。
“你说谎!”小鱼说。
“如果你不相信,就等着瞧吧!”蝌蚪说。
小鱼已经三天没有看见蝌蚪了,他很担心。为了寻找朋友,他搜索了每一个地方,蝌蚪能去哪儿呢?
几天以后,蝌蚪又出现了。小鱼非常高兴,可是他又吃了一惊,蝌蚪又长出了一对腿,哦,还有,尾巴也变短了。
“这几天你到什么地方去了?”小鱼问。
“我去陆地上了,我不是告诉你我在这儿呆不长吗?看,现在我有四条腿了。不久我就要长期在地上生活了。”蝌蚪说,“现在请不要再叫我蝌蚪,就叫我青蛙吧。再见了,鱼!”
小鱼眨眨眼睛,被他朋友的话搞糊涂了。他不能相信听到和看到的,因为从前他的朋友能像鱼一样游泳,从前他没有腿,而现在却不是这样了。
小鱼独自留在池塘里,最后,他变成了一条大鱼。
一天,当这条鱼在池塘里游水寻找食物的时候,一只青蛙跳进水里。他就是鱼的老朋友。看到朋友,鱼非常高兴。
“你去哪儿啦?”鱼问。
“我一直在陆地上。”青蛙说,并且把他在陆地上遇到的事情告诉鱼。
鱼听了青蛙的故事后,问:“在那边谁是你的朋友呀?”
“我有好多朋友,像牛啦,鸟啦,猫啦,还有其它许多动物。”青蛙说。
“我能跟你去陆地吗?我想见见他们。”鱼说。
“那怎么行!你在陆地上不能呼吸,你会死的。”青蛙解释说。
“可是我想去看看牛啦,鸟啦,还有别的你刚才告诉我的朋友们。”鱼请求道。
“你不必亲自去,我给你说说不就行了。”青蛙说。于是青蛙向鱼讲了许多他的陆地朋友的事。鱼试着想象那些动物的样子,但是他总是不满意。
“你在陆地上还看见什么别的吗?”鱼又问。
“还有人,有孩子,有玩具和别的许多东西。”青蛙继续说。
他们一直聊到晚上。鱼很不开心,因为他不能去陆地上看这些奇怪的事,这天晚上他失眠了。他满脑子都是白天听到的各种各样的事情。
第二天早上,鱼去寻找食物。忽然,他看见水面上天空飞鸟的倒影。他太想看看鸟了,就鼓起勇气试着跳到河岸上。鱼一纵身跳上了岸。但是在他睁开眼睛之前,他已经喘不过气了。他开始呻吟起来。算他走运,青蛙正好在附近找吃的。
青蛙马上跳到奄奄一息的鱼的身边。他一点也没耽搁,把大鱼拉进池塘里。鱼一进水,立刻苏醒了。他很惊讶,问青蛙发生了什么事。
青蛙微笑着说:“我跟你说过,你不能到陆地上去。不管你是在陆地上还是在水里都没关系,一切都是美好的,美丽的,为什么你不愿意听我的话!”
“可是只呆在这儿我觉得不满意。”鱼继续说。
“你应该满足了,”青蛙劝鱼说:“没有多少生命能像你一样呆在水里。”
鱼笑了,高兴地在水草间游来游去,他认识到他的朋友说的是真的。
故事4:橘子老虎
秋天,橘子熟了,那黄澄澄的蜜橘挂满枝头,远远望去,就像一个个的小灯笼。一天,一个最大、最沉的橘子,看到同伴被人摘走,伤心地对橘子树说:“妈妈,难道我们橘子生来就应该被人吃掉吗?”
“是啊,孩子,我们的最大愿望就是丰富人类的美好生活。我们身上的果核将会落入泥土,然后新的生命又会破土而出。”
“不,妈妈,我可不愿为他人活着,更不愿被人随意摘取,我要变成一个人见人怕的老虎。”
“孩子,那可不是我们橘子的风格。”
“不嘛!”大橘子纵身一跳,从树上跳到了地面。奇怪的事情发生了:大橘子的身体不断地胀大、胀大,最后圆圆的橘子肚子拉长了,橘皮上竟然浮现出色彩斑斓的花纹来,前面拱出一个脑袋,脑门正中有一个醒目的“王”字,后面露出一条长尾巴来。哈,大橘子竟然变成一只威风凛凛的橘子老虎啦!
橘子老虎非常高兴,他告别妈妈,决定到各地去旅行,让大家见识见识橘子老虎的威风。他翻过一座山冈,看见一只小羊正趴在大树下低声哭泣。他决定吓唬吓唬小羊。他蹑手蹑脚地走上前去,却发现一只大灰狼也在偷偷地逼近小羊。橘子老虎见状忙喊道:“小羊,当心大灰狼!”
小羊一惊,扭头要跑,大灰狼扑上来抓住了小羊。在这千钧一发的时刻,只见橘子老虎一个箭步蹿上前去,对着大灰狼喊道:“大灰狼,快放开他,要不我就撕碎你!”
大灰狼一愣,见一只猛虎向自己扑来,不由两腿直打颤。虽然舍不得到嘴的肥嫩小羊,但这只老虎可不是等闲之辈,只得丢下小羊逃走了。面对橘子老虎,小羊流着眼泪说道:“虎大王,我妈妈得了重病,想吃橘子,您能否等我找到橘子后再吃我呢?”
看看楚楚可怜的小羊,橘子老虎被他爱妈妈的孝心所感动了。他安慰小羊:“别怕,让我帮助你完成心愿。”说着,他便撕开自己的橘子肚皮,掰下一瓣橘子,递给小羊说,“给你妈妈送回去吧。”
小羊感激地说:“您真是天下心眼最好的老虎呀!”
橘子老虎笑了笑,继续往前赶路了。
故事5:狐狸假扮兽王
很早以前,森林中的百兽过着闲逸、安乐的生活。因没有兽王,便商议决定寻找一个有资格作兽王的动物来领导群兽,于是四处寻觅。一天,有只狐狸跑到一家染衣坊寻找食物,不慎掉进了染缸。它惊恐万分,拼命挣扎,等到爬出染缸时,已是精疲力尽。狐狸再也没有心思寻找食物,落荒而逃。它在河边喝水时,见到水中的倒影,忽然发现自己身上的颜色变得美丽异常,与众不同。狐狸自己知道那是在染缸里染上的。正在这时,寻找兽王的动物们发现了它,惊奇地问它是什么动物?是从什么地方来的?狐狸灵机一动,诈称自己是天帝派来作兽王的。群兽从来没有见到过它这样的动物,又听说是天帝派来的,便生起信心,拥立狐狸为王。
当上兽王的狐狸,得意忘形,作威作福。它不但役使所有的野兽为自己做事,还忘乎所以地让狮子当坐骑,四处巡视游玩。照理说狐狸当了兽王,应该对自己的同类特别关照才是,但这兽王并没有这样做,反而痛恨狐群,百般加以折磨。动物们本以为有兽王领导,生活会更加幸福、快乐,没想到却落得如此痛苦。众狐狸更觉得是飞来的横祸,大惑不解,暗地里对兽王进行观察,它们怀疑这天帝所赐的兽王可能是狐狸装扮的。众狐狸找了个机会,偷偷地询问狮子:“每月十五,月圆之日,兽王是否仍要骑着你去游玩?”狮子说:“不,兽王每月十五都给我放假,它总是单独离去。”群狐说:“我们狐狸因为业力的关系,每到十五日就会昏迷一阵,好一会儿才能恢复。你可以在十五日那天跟踪兽王,看它是不是狐狸所扮?”
等到十五日,兽王照常向远处跑去。狮子便悄悄地跟在后面,到了一个山洞里,果然看见兽王像死尸一样倒在地上,昏迷不醒。狮子这才知道动物们都上当受骗了,尤其是自己,居然被狐狸当坐骑戏弄了这么长的时间,狮子羞怒难当,一跃而上将这只狐狸吞食了……
群兽因为没有好好观察,让一只卑劣的狐狸当了兽王。最后的结果是让群兽都受到了莫大的痛苦,那自作聪明的狐狸也自取灭亡。
================================================
FILE: pom.xml
================================================
4.0.0
com.ai
AI-java
pom
1.0
ai-common
ai-openai
ai-spark
ai-baidu
ai-core
AI-java
ai-sdk-java
AI-java
org.apache.maven.plugins
maven-surefire-plugin
2.12.4
true
org.apache.maven.plugins
maven-compiler-plugin
8
8
org.apache.maven.plugins
maven-jar-plugin
2.3.1
true