baichuan stream support

This commit is contained in:
robin
2023-11-01 23:13:35 +08:00
parent 824237d1ca
commit b24cf12074
2 changed files with 24 additions and 7 deletions

View File

@ -63,13 +63,16 @@ public class BaichuanAIClient {
public static void refresh() {
String apiKey = "";
String apiHost = "https://api.baichuan-ai.com/v1/chat/";
String apiHost = "https://api.baichuan-ai.com/v1/stream/chat";
String model = "Baichuan2-53B";
String secretKey = "";
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config apiHostConfig = configService.find(BAICHUAN_HOST).getData();
if (apiHostConfig != null && StringUtils.isNotBlank(apiHostConfig.getContent())) {
apiHost = apiHostConfig.getContent();
if (apiHost.endsWith("/")) {
apiHost = apiHost.substring(0, apiHost.length() - 1);
}
}
Config config = configService.find(BAICHUAN_API_KEY).getData();
if (config != null && StringUtils.isNotBlank(config.getContent())) {

View File

@ -10,16 +10,15 @@ import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.*;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import okio.BufferedSource;
import org.apache.commons.collections4.CollectionUtils;
import org.jetbrains.annotations.NotNull;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
@ -195,7 +194,6 @@ public class BaichuanAIStreamClient {
chatCompletionsOptions.setModel(this.model);
chatCompletionsOptions.setMessages(chatMessages);
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
ObjectMapper mapper = new ObjectMapper();
String requestBody = mapper.writeValueAsString(chatCompletionsOptions);
Request request = new Request.Builder()
@ -203,7 +201,23 @@ public class BaichuanAIStreamClient {
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
.build();
//创建事件
EventSource eventSource = factory.newEventSource(request, eventSourceListener);
// 发送请求并处理响应
try (Response response = this.okHttpClient.newCall(request).execute()) {
if (!response.isSuccessful()) {
throw new IOException("Unexpected code " + response);
}
// 读取并输出响应数据
BufferedSource source = response.body().source();
while (!source.exhausted()) {
String content = source.readUtf8Line();
eventSourceListener.onEvent(null, "[DATA]", null, content);
}
eventSourceListener.onEvent(null, "[DONE]", null, "[DONE]");
} catch (Exception e) {
log.error("baichuan ai error", e);
}
log.info("finish invoking baichuan ai");
} catch (Exception e) {
log.error("baichuan ai error", e);