Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/main/java/run/halo/live2d/Live2dPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,23 @@
import org.springframework.stereotype.Component;
import run.halo.app.plugin.BasePlugin;
import run.halo.app.plugin.PluginContext;
import run.halo.live2d.chat.ChatSecurityService;

/**
* @author LIlGG
* @since 2022-11-30
*/
@Component
public class Live2dPlugin extends BasePlugin {
private final ChatSecurityService chatSecurityService;

public Live2dPlugin(PluginContext context) {
public Live2dPlugin(PluginContext context, ChatSecurityService chatSecurityService) {
super(context);
this.chatSecurityService = chatSecurityService;
}

@Override
public void stop() {
chatSecurityService.dispose();
}
}
9 changes: 6 additions & 3 deletions src/main/java/run/halo/live2d/Live2dSettingProcess.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,14 @@ private void copyAiChatFields(ObjectNode target, JsonNode source) {
}

var aiChatBaseSetting = source.get("aiChatBaseSetting");
if (aiChatBaseSetting == null || aiChatBaseSetting.isNull()) {
return;
if (aiChatBaseSetting != null && !aiChatBaseSetting.isNull()) {
copyFields(target, aiChatBaseSetting, AI_CHAT_PUBLIC_FIELDS);
}

copyFields(target, aiChatBaseSetting, AI_CHAT_PUBLIC_FIELDS);
var aiChatDisplaySetting = source.get("aiChatDisplaySetting");
if (aiChatDisplaySetting != null && !aiChatDisplaySetting.isNull()) {
copyFields(target, aiChatDisplaySetting, AI_CHAT_PUBLIC_FIELDS);
}
}

private void copyAgentFields(ObjectNode target, JsonNode source) {
Expand Down
87 changes: 75 additions & 12 deletions src/main/java/run/halo/live2d/chat/AiChatEndpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import static org.springdoc.core.fn.builders.content.Builder.contentBuilder;
import static org.springdoc.core.fn.builders.requestbody.Builder.requestBodyBuilder;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
Expand Down Expand Up @@ -47,6 +49,8 @@ public class AiChatEndpoint implements CustomEndpoint {

private final AgentToolService agentToolService;

private final ChatSecurityService chatSecurityService;

private final ObjectMapper objectMapper = new ObjectMapper();

@Override
Expand Down Expand Up @@ -75,7 +79,7 @@ public RouterFunction<ServerResponse> endpoint() {
private Mono<ServerResponse> chatProcess(ServerRequest request) {
return request.bodyToMono(Map.class)
.map(this::toChatRequest)
.flatMap(this::chatCompletion)
.flatMap(chatRequest -> chatCompletion(request, chatRequest))
.onErrorMap(InvalidUIMessageException.class,
throwable -> new ResponseStatusException(HttpStatus.BAD_REQUEST,
throwable.getMessage(), throwable))
Expand All @@ -92,7 +96,8 @@ private Mono<ServerResponse> chatProcess(ServerRequest request) {
}


private Mono<UIMessageStreamResponse> chatCompletion(UIMessageChatRequest<Void> chatRequest) {
private Mono<UIMessageStreamResponse> chatCompletion(ServerRequest request,
UIMessageChatRequest<Void> chatRequest) {
return reactiveSettingFetcher.fetch("aichat", AiChatConfig.class)
.flatMap(aiChatConfig -> {
if (!aiChatConfig.isAiChat()) {
Expand All @@ -103,7 +108,8 @@ private Mono<UIMessageStreamResponse> chatCompletion(UIMessageChatRequest<Void>
var accessMode = baseSetting.resolvedAccessMode();

if (!accessMode.authenticationRequired()) {
return loadAgentSettings()
return chatSecurityService.secure(request, aiChatConfig.securitySetting(), null)
.then(loadAgentSettings())
.map(settings -> agentToolService.buildTools(settings, accessMode, false))
.flatMap(toolSet -> aiChatService.streamChatCompletion(
baseSetting.modelName(),
Expand All @@ -118,14 +124,16 @@ private Mono<UIMessageStreamResponse> chatCompletion(UIMessageChatRequest<Void>
.filter(this::isAuthenticated)
.switchIfEmpty(Mono.error(
new ResponseStatusException(HttpStatus.UNAUTHORIZED, "请先登录")))
.flatMap(authentication -> loadAgentSettings()
.map(settings -> agentToolService.buildTools(settings, accessMode, true))
.flatMap(toolSet -> aiChatService.streamChatCompletion(
baseSetting.modelName(),
agentToolService.appendCapabilityPrompt(
baseSetting.systemMessage(), toolSet),
chatRequest,
toolSet)));
.flatMap(authentication -> chatSecurityService.secure(request,
aiChatConfig.securitySetting(), authentication)
.then(loadAgentSettings()
.map(settings -> agentToolService.buildTools(settings, accessMode, true))
.flatMap(toolSet -> aiChatService.streamChatCompletion(
baseSetting.modelName(),
agentToolService.appendCapabilityPrompt(
baseSetting.systemMessage(), toolSet),
chatRequest,
toolSet))));
});
}

Expand Down Expand Up @@ -159,14 +167,25 @@ private String serializeChunk(UIMessageChunk chunk) {
}
}

record AiChatConfig(boolean isAiChat, AiChatBaseSetting aiChatBaseSetting) {
@JsonIgnoreProperties(ignoreUnknown = true)
record AiChatConfig(boolean isAiChat, AiChatBaseSetting aiChatBaseSetting,
AiChatDisplaySetting aiChatDisplaySetting,
AiChatSecuritySetting aiChatSecuritySetting) {
AiChatConfig {
if (isAiChat && aiChatBaseSetting == null) {
throw new IllegalArgumentException("ai chat base setting must not be null");
}
}

AiChatSecuritySetting securitySetting() {
if (aiChatSecuritySetting != null) {
return aiChatSecuritySetting;
}
return AiChatSecuritySetting.defaults();
}
}

@JsonIgnoreProperties(ignoreUnknown = true)
record AiChatBaseSetting(Boolean isAnonymous, String accessMode, String systemMessage,
String modelName) {
AiChatBaseSetting {
Expand All @@ -183,6 +202,50 @@ AgentAccessMode resolvedAccessMode() {
}
}

@JsonIgnoreProperties(ignoreUnknown = true)
record AiChatDisplaySetting(Integer chatContextRounds, Integer chunkTimeout,
Integer showChatMessageTimeout,
Integer autoContinuationMessageMinVisibleMs,
String requestAcceptedMessage, Object reasoningMessages,
Integer reasoningMessageInterval) {
}

@JsonIgnoreProperties(ignoreUnknown = true)
record AiChatSecuritySetting(Boolean antiHotlinkEnabled, Boolean allowMissingOrigin,
Object allowedOrigins, Boolean rateLimitEnabled,
Integer rateLimitRequests, Integer rateLimitWindowSeconds) {
AiChatSecuritySetting {
antiHotlinkEnabled = antiHotlinkEnabled == null ? true : antiHotlinkEnabled;
allowMissingOrigin = allowMissingOrigin == null ? false : allowMissingOrigin;
allowedOrigins = allowedOrigins == null ? List.of() : allowedOrigins;
rateLimitEnabled = rateLimitEnabled == null ? true : rateLimitEnabled;
rateLimitRequests = rateLimitRequests == null ? 20 : rateLimitRequests;
rateLimitWindowSeconds = rateLimitWindowSeconds == null ? 60 : rateLimitWindowSeconds;
}

static AiChatSecuritySetting defaults() {
return new AiChatSecuritySetting(true, false, List.of(), true, 20, 60);
}

List<String> normalizedAllowedOrigins() {
return ChatSecurityService.normalizeStringList(allowedOrigins, "origin");
}

int normalizedRateLimitRequests() {
if (rateLimitRequests < 1 || rateLimitRequests > 1000) {
return 20;
}
return rateLimitRequests;
}

int normalizedRateLimitWindowSeconds() {
if (rateLimitWindowSeconds < 10 || rateLimitWindowSeconds > 86400) {
return 60;
}
return rateLimitWindowSeconds;
}
}

@Override
public GroupVersion groupVersion() {
return GroupVersion.parseAPIVersion("api.live2d.halo.run/v1alpha1");
Expand Down
Loading