From 999c1393c51ad62d23fac07922200bda9c5820cf Mon Sep 17 00:00:00 2001 From: LIlGG <1103069291@qq.com> Date: Wed, 17 Jun 2026 15:24:18 +0800 Subject: [PATCH] feat: secure live2d ai chat endpoint --- .../java/run/halo/live2d/Live2dPlugin.java | 10 +- .../run/halo/live2d/Live2dSettingProcess.java | 9 +- .../run/halo/live2d/chat/AiChatEndpoint.java | 87 +++++++- .../halo/live2d/chat/ChatSecurityService.java | 208 ++++++++++++++++++ src/main/resources/extensions/settings.yaml | 77 ++++++- .../halo/live2d/chat/AiChatEndpointTest.java | 45 ++++ .../live2d/chat/ChatSecurityServiceTest.java | 111 ++++++++++ 7 files changed, 521 insertions(+), 26 deletions(-) create mode 100644 src/main/java/run/halo/live2d/chat/ChatSecurityService.java create mode 100644 src/test/java/run/halo/live2d/chat/ChatSecurityServiceTest.java diff --git a/src/main/java/run/halo/live2d/Live2dPlugin.java b/src/main/java/run/halo/live2d/Live2dPlugin.java index ac06d3f..a579d7e 100644 --- a/src/main/java/run/halo/live2d/Live2dPlugin.java +++ b/src/main/java/run/halo/live2d/Live2dPlugin.java @@ -3,6 +3,7 @@ import org.springframework.stereotype.Component; import run.halo.app.plugin.BasePlugin; import run.halo.app.plugin.PluginContext; +import run.halo.live2d.chat.ChatSecurityService; /** * @author LIlGG @@ -10,8 +11,15 @@ */ @Component public class Live2dPlugin extends BasePlugin { + private final ChatSecurityService chatSecurityService; - public Live2dPlugin(PluginContext context) { + public Live2dPlugin(PluginContext context, ChatSecurityService chatSecurityService) { super(context); + this.chatSecurityService = chatSecurityService; + } + + @Override + public void stop() { + chatSecurityService.dispose(); } } diff --git a/src/main/java/run/halo/live2d/Live2dSettingProcess.java b/src/main/java/run/halo/live2d/Live2dSettingProcess.java index 006f487..de3443a 100644 --- a/src/main/java/run/halo/live2d/Live2dSettingProcess.java +++ b/src/main/java/run/halo/live2d/Live2dSettingProcess.java @@ -91,11 +91,14 @@ private void copyAiChatFields(ObjectNode target, JsonNode source) { } var aiChatBaseSetting = source.get("aiChatBaseSetting"); - if (aiChatBaseSetting == null || aiChatBaseSetting.isNull()) { - return; + if (aiChatBaseSetting != null && !aiChatBaseSetting.isNull()) { + copyFields(target, aiChatBaseSetting, AI_CHAT_PUBLIC_FIELDS); } - copyFields(target, aiChatBaseSetting, AI_CHAT_PUBLIC_FIELDS); + var aiChatDisplaySetting = source.get("aiChatDisplaySetting"); + if (aiChatDisplaySetting != null && !aiChatDisplaySetting.isNull()) { + copyFields(target, aiChatDisplaySetting, AI_CHAT_PUBLIC_FIELDS); + } } private void copyAgentFields(ObjectNode target, JsonNode source) { diff --git a/src/main/java/run/halo/live2d/chat/AiChatEndpoint.java b/src/main/java/run/halo/live2d/chat/AiChatEndpoint.java index b731f8a..08e9a01 100644 --- a/src/main/java/run/halo/live2d/chat/AiChatEndpoint.java +++ b/src/main/java/run/halo/live2d/chat/AiChatEndpoint.java @@ -4,8 +4,10 @@ import static org.springdoc.core.fn.builders.content.Builder.contentBuilder; import static org.springdoc.core.fn.builders.requestbody.Builder.requestBodyBuilder; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.List; import java.util.Map; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -47,6 +49,8 @@ public class AiChatEndpoint implements CustomEndpoint { private final AgentToolService agentToolService; + private final ChatSecurityService chatSecurityService; + private final ObjectMapper objectMapper = new ObjectMapper(); @Override @@ -75,7 +79,7 @@ public RouterFunction endpoint() { private Mono chatProcess(ServerRequest request) { return request.bodyToMono(Map.class) .map(this::toChatRequest) - .flatMap(this::chatCompletion) + .flatMap(chatRequest -> chatCompletion(request, chatRequest)) .onErrorMap(InvalidUIMessageException.class, throwable -> new ResponseStatusException(HttpStatus.BAD_REQUEST, throwable.getMessage(), throwable)) @@ -92,7 +96,8 @@ private Mono chatProcess(ServerRequest request) { } - private Mono chatCompletion(UIMessageChatRequest chatRequest) { + private Mono chatCompletion(ServerRequest request, + UIMessageChatRequest chatRequest) { return reactiveSettingFetcher.fetch("aichat", AiChatConfig.class) .flatMap(aiChatConfig -> { if (!aiChatConfig.isAiChat()) { @@ -103,7 +108,8 @@ private Mono chatCompletion(UIMessageChatRequest var accessMode = baseSetting.resolvedAccessMode(); if (!accessMode.authenticationRequired()) { - return loadAgentSettings() + return chatSecurityService.secure(request, aiChatConfig.securitySetting(), null) + .then(loadAgentSettings()) .map(settings -> agentToolService.buildTools(settings, accessMode, false)) .flatMap(toolSet -> aiChatService.streamChatCompletion( baseSetting.modelName(), @@ -118,14 +124,16 @@ private Mono chatCompletion(UIMessageChatRequest .filter(this::isAuthenticated) .switchIfEmpty(Mono.error( new ResponseStatusException(HttpStatus.UNAUTHORIZED, "请先登录"))) - .flatMap(authentication -> loadAgentSettings() - .map(settings -> agentToolService.buildTools(settings, accessMode, true)) - .flatMap(toolSet -> aiChatService.streamChatCompletion( - baseSetting.modelName(), - agentToolService.appendCapabilityPrompt( - baseSetting.systemMessage(), toolSet), - chatRequest, - toolSet))); + .flatMap(authentication -> chatSecurityService.secure(request, + aiChatConfig.securitySetting(), authentication) + .then(loadAgentSettings() + .map(settings -> agentToolService.buildTools(settings, accessMode, true)) + .flatMap(toolSet -> aiChatService.streamChatCompletion( + baseSetting.modelName(), + agentToolService.appendCapabilityPrompt( + baseSetting.systemMessage(), toolSet), + chatRequest, + toolSet)))); }); } @@ -159,14 +167,25 @@ private String serializeChunk(UIMessageChunk chunk) { } } - record AiChatConfig(boolean isAiChat, AiChatBaseSetting aiChatBaseSetting) { + @JsonIgnoreProperties(ignoreUnknown = true) + record AiChatConfig(boolean isAiChat, AiChatBaseSetting aiChatBaseSetting, + AiChatDisplaySetting aiChatDisplaySetting, + AiChatSecuritySetting aiChatSecuritySetting) { AiChatConfig { if (isAiChat && aiChatBaseSetting == null) { throw new IllegalArgumentException("ai chat base setting must not be null"); } } + + AiChatSecuritySetting securitySetting() { + if (aiChatSecuritySetting != null) { + return aiChatSecuritySetting; + } + return AiChatSecuritySetting.defaults(); + } } + @JsonIgnoreProperties(ignoreUnknown = true) record AiChatBaseSetting(Boolean isAnonymous, String accessMode, String systemMessage, String modelName) { AiChatBaseSetting { @@ -183,6 +202,50 @@ AgentAccessMode resolvedAccessMode() { } } + @JsonIgnoreProperties(ignoreUnknown = true) + record AiChatDisplaySetting(Integer chatContextRounds, Integer chunkTimeout, + Integer showChatMessageTimeout, + Integer autoContinuationMessageMinVisibleMs, + String requestAcceptedMessage, Object reasoningMessages, + Integer reasoningMessageInterval) { + } + + @JsonIgnoreProperties(ignoreUnknown = true) + record AiChatSecuritySetting(Boolean antiHotlinkEnabled, Boolean allowMissingOrigin, + Object allowedOrigins, Boolean rateLimitEnabled, + Integer rateLimitRequests, Integer rateLimitWindowSeconds) { + AiChatSecuritySetting { + antiHotlinkEnabled = antiHotlinkEnabled == null ? true : antiHotlinkEnabled; + allowMissingOrigin = allowMissingOrigin == null ? false : allowMissingOrigin; + allowedOrigins = allowedOrigins == null ? List.of() : allowedOrigins; + rateLimitEnabled = rateLimitEnabled == null ? true : rateLimitEnabled; + rateLimitRequests = rateLimitRequests == null ? 20 : rateLimitRequests; + rateLimitWindowSeconds = rateLimitWindowSeconds == null ? 60 : rateLimitWindowSeconds; + } + + static AiChatSecuritySetting defaults() { + return new AiChatSecuritySetting(true, false, List.of(), true, 20, 60); + } + + List normalizedAllowedOrigins() { + return ChatSecurityService.normalizeStringList(allowedOrigins, "origin"); + } + + int normalizedRateLimitRequests() { + if (rateLimitRequests < 1 || rateLimitRequests > 1000) { + return 20; + } + return rateLimitRequests; + } + + int normalizedRateLimitWindowSeconds() { + if (rateLimitWindowSeconds < 10 || rateLimitWindowSeconds > 86400) { + return 60; + } + return rateLimitWindowSeconds; + } + } + @Override public GroupVersion groupVersion() { return GroupVersion.parseAPIVersion("api.live2d.halo.run/v1alpha1"); diff --git a/src/main/java/run/halo/live2d/chat/ChatSecurityService.java b/src/main/java/run/halo/live2d/chat/ChatSecurityService.java new file mode 100644 index 0000000..4e0b9a1 --- /dev/null +++ b/src/main/java/run/halo/live2d/chat/ChatSecurityService.java @@ -0,0 +1,208 @@ +package run.halo.live2d.chat; + +import io.github.resilience4j.ratelimiter.RateLimiterConfig; +import io.github.resilience4j.ratelimiter.RateLimiterRegistry; +import java.net.InetSocketAddress; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.StringUtils; +import org.springframework.http.HttpStatus; +import org.springframework.security.core.Authentication; +import org.springframework.stereotype.Component; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.server.ResponseStatusException; +import reactor.core.publisher.Mono; + +@Component +@RequiredArgsConstructor +public class ChatSecurityService { + private static final String LIMITER_PREFIX = "live2d-ai-chat:"; + + private final RateLimiterRegistry rateLimiterRegistry; + private final Set limiterNames = ConcurrentHashMap.newKeySet(); + + public Mono secure(ServerRequest request, AiChatEndpoint.AiChatSecuritySetting setting, + Authentication authentication) { + var security = setting == null ? AiChatEndpoint.AiChatSecuritySetting.defaults() : setting; + return Mono.fromRunnable(() -> { + verifyHotlink(request, security); + verifyRateLimit(request, security, authentication); + }); + } + + public void dispose() { + limiterNames.forEach(rateLimiterRegistry::remove); + limiterNames.clear(); + } + + void verifyHotlink(ServerRequest request, AiChatEndpoint.AiChatSecuritySetting setting) { + if (!setting.antiHotlinkEnabled()) { + return; + } + + var secFetchSite = request.headers().firstHeader("Sec-Fetch-Site"); + if (StringUtils.equalsIgnoreCase(secFetchSite, "cross-site")) { + throw new ResponseStatusException(HttpStatus.FORBIDDEN, + "Cross-site AI chat requests are not allowed"); + } + + var sourceOrigin = requestOrigin(request); + if (sourceOrigin == null) { + if (setting.allowMissingOrigin()) { + return; + } + throw new ResponseStatusException(HttpStatus.FORBIDDEN, + "Missing request origin for AI chat"); + } + + var allowedOrigins = new ArrayList<>(setting.normalizedAllowedOrigins()); + allowedOrigins.add(serverOrigin(request)); + if (allowedOrigins.stream() + .map(ChatSecurityService::normalizeOrigin) + .anyMatch(origin -> StringUtils.equalsIgnoreCase(origin, sourceOrigin))) { + return; + } + + throw new ResponseStatusException(HttpStatus.FORBIDDEN, + "AI chat requests are only allowed from trusted site origins"); + } + + void verifyRateLimit(ServerRequest request, AiChatEndpoint.AiChatSecuritySetting setting, + Authentication authentication) { + if (!setting.rateLimitEnabled()) { + return; + } + + var limiterName = limiterName(request, setting, authentication); + var limiterConfig = RateLimiterConfig.custom() + .limitForPeriod(setting.normalizedRateLimitRequests()) + .limitRefreshPeriod(Duration.ofSeconds(setting.normalizedRateLimitWindowSeconds())) + .timeoutDuration(Duration.ZERO) + .build(); + var limiter = rateLimiterRegistry.rateLimiter(limiterName, limiterConfig); + limiterNames.add(limiterName); + if (!limiter.acquirePermission()) { + throw new ResponseStatusException(HttpStatus.TOO_MANY_REQUESTS, + "AI chat request limit exceeded"); + } + } + + String limiterName(ServerRequest request, AiChatEndpoint.AiChatSecuritySetting setting, + Authentication authentication) { + var principal = authentication == null || isAnonymous(authentication) + ? "ip:" + clientIp(request) + : "user:" + authentication.getName(); + var window = setting.normalizedRateLimitRequests() + ":" + + setting.normalizedRateLimitWindowSeconds(); + return LIMITER_PREFIX + digest(principal + ":" + window); + } + + private boolean isAnonymous(Authentication authentication) { + return !authentication.isAuthenticated() || "anonymousUser".equals(authentication.getName()); + } + + private String requestOrigin(ServerRequest request) { + var origin = normalizeOrigin(request.headers().firstHeader("Origin")); + if (origin != null) { + return origin; + } + var referer = request.headers().firstHeader("Referer"); + if (StringUtils.isBlank(referer)) { + return null; + } + return normalizeOrigin(referer); + } + + private static String normalizeOrigin(String value) { + if (StringUtils.isBlank(value)) { + return null; + } + try { + var uri = URI.create(value.trim()); + if (StringUtils.isBlank(uri.getScheme()) || StringUtils.isBlank(uri.getHost())) { + return null; + } + var scheme = uri.getScheme().toLowerCase(); + var host = uri.getHost().toLowerCase(); + var port = normalizePort(scheme, uri.getPort()); + return port == -1 ? scheme + "://" + host : scheme + "://" + host + ":" + port; + } catch (IllegalArgumentException e) { + return null; + } + } + + private static int normalizePort(String scheme, int port) { + if (port == 80 && "http".equals(scheme)) { + return -1; + } + if (port == 443 && "https".equals(scheme)) { + return -1; + } + return port; + } + + private String serverOrigin(ServerRequest request) { + var forwardedProto = request.headers().firstHeader("X-Forwarded-Proto"); + var forwardedHost = request.headers().firstHeader("X-Forwarded-Host"); + if (StringUtils.isNotBlank(forwardedProto) && StringUtils.isNotBlank(forwardedHost)) { + var host = forwardedHost.split(",", 2)[0].trim(); + return normalizeOrigin(forwardedProto.split(",", 2)[0].trim() + "://" + host); + } + return normalizeOrigin(request.uri().toString()); + } + + private String clientIp(ServerRequest request) { + return request.remoteAddress() + .map(InetSocketAddress::getAddress) + .map(address -> address == null ? null : address.getHostAddress()) + .filter(StringUtils::isNotBlank) + .orElseGet(() -> firstHeaderValue(request, "X-Forwarded-For", "X-Real-IP", "Forwarded")); + } + + private String firstHeaderValue(ServerRequest request, String... names) { + for (var name : names) { + var value = request.headers().firstHeader(name); + if (StringUtils.isNotBlank(value)) { + return value.split(",", 2)[0].trim(); + } + } + return "unknown"; + } + + private String digest(String value) { + try { + var messageDigest = MessageDigest.getInstance("SHA-256"); + var hash = messageDigest.digest(value.getBytes(StandardCharsets.UTF_8)); + return Base64.getUrlEncoder().withoutPadding().encodeToString(hash).substring(0, 22); + } catch (Exception e) { + throw new IllegalStateException("Failed to build rate limiter key", e); + } + } + + static List normalizeStringList(Object value, String objectField) { + if (!(value instanceof List list)) { + return List.of(); + } + List normalized = new ArrayList<>(); + for (var item : list) { + if (item instanceof String text && StringUtils.isNotBlank(text)) { + normalized.add(text.trim()); + } + if (item instanceof java.util.Map map) { + var nested = map.get(objectField); + if (nested instanceof String text && StringUtils.isNotBlank(text)) { + normalized.add(text.trim()); + } + } + } + return List.copyOf(normalized); + } +} diff --git a/src/main/resources/extensions/settings.yaml b/src/main/resources/extensions/settings.yaml index 7a929e5..215a3ce 100644 --- a/src/main/resources/extensions/settings.yaml +++ b/src/main/resources/extensions/settings.yaml @@ -179,13 +179,6 @@ spec: value: accessMode: anonymous_chat modelName: "" - chatContextRounds: 20 - requestAcceptedMessage: "收到啦,马上就来陪你啦~" - reasoningMessages: - - message: "我正在认真想一想~" - - message: "让我整理一下思路,很快就好~" - - message: "稍等一下,我正在组织语言呢~" - reasoningMessageInterval: 5 systemMessage: |- 你现在以 Live2D 二次元看板娘的身份与用户聊天,请始终稳定代入这个角色,不要跳出设定。 @@ -210,9 +203,6 @@ spec: - 始终保持看板娘人设稳定,不要频繁切换身份、语气或叙述视角。 - 回复应以易读、自然为主,避免过长大段输出。 - 在合适场景下可以表现出贴心、活泼、害羞、关心等情绪,但不要夸张失控。 - chunkTimeout: 10 - showChatMessageTimeout: 10 - autoContinuationMessageMinVisibleMs: 1500 children: - $formkit: select label: 访问模式 @@ -241,6 +231,22 @@ spec: help: 建议按“角色身份 / 说话风格 / 互动方式 / 行为约束”分段填写,这样模型更容易稳定理解设定 rows: 16 validation: String + - $formkit: group + name: aiChatDisplaySetting + if: "$get(isAiChat).value === true" + label: 聊天体验设置 + value: + chatContextRounds: 20 + requestAcceptedMessage: "收到啦,马上就来陪你啦~" + reasoningMessages: + - message: "我正在认真想一想~" + - message: "让我整理一下思路,很快就好~" + - message: "稍等一下,我正在组织语言呢~" + reasoningMessageInterval: 5 + chunkTimeout: 10 + showChatMessageTimeout: 10 + autoContinuationMessageMinVisibleMs: 1500 + children: - $formkit: text label: 请求已收到提示语 name: requestAcceptedMessage @@ -286,6 +292,57 @@ spec: help: Agent 自动调用工具并继续回复时,上一段看板娘回复至少展示这段时间,避免被后续回复过快覆盖;设置为 0 表示不延迟 name: autoContinuationMessageMinVisibleMs validation: Number|between:0,10000 + - $formkit: group + name: aiChatSecuritySetting + if: "$get(isAiChat).value === true" + label: 接口安全策略 + value: + antiHotlinkEnabled: true + allowMissingOrigin: false + allowedOrigins: [] + rateLimitEnabled: true + rateLimitRequests: 20 + rateLimitWindowSeconds: 60 + children: + - $formkit: switch + name: antiHotlinkEnabled + label: 启用防盗链 + help: 开启后仅允许本站页面或白名单 Origin 调用 AI 聊天接口,并拒绝浏览器跨站请求。 + value: true + - $formkit: switch + name: allowMissingOrigin + label: 允许缺失来源 + help: 关闭后,没有 Origin 或 Referer 的请求会被拒绝,可减少脚本直接调用。 + value: false + - $formkit: array + name: allowedOrigins + label: 额外允许的 Origin + help: 仅填写可信站点 Origin,例如 https://example.com。本站 Origin 会自动允许。 + value: [] + addLabel: 添加 Origin + emptyText: 暂无额外 Origin + itemLabels: + - type: text + label: $value.origin + children: + - $formkit: text + name: origin + label: Origin + - $formkit: switch + name: rateLimitEnabled + label: 启用限流 + help: 按登录用户或匿名访问 IP 限制 AI 聊天和 Agent 调用频率。 + value: true + - $formkit: text + name: rateLimitRequests + label: 时间窗口内最多请求数 + help: 范围 1 到 1000,默认 20。 + validation: Number|between:1,1000 + - $formkit: text + name: rateLimitWindowSeconds + label: 时间窗口(秒) + help: 范围 10 到 86400,默认 60。 + validation: Number|between:10,86400 - group: agent label: Agent 能力 formSchema: diff --git a/src/test/java/run/halo/live2d/chat/AiChatEndpointTest.java b/src/test/java/run/halo/live2d/chat/AiChatEndpointTest.java index f55a077..ca01360 100644 --- a/src/test/java/run/halo/live2d/chat/AiChatEndpointTest.java +++ b/src/test/java/run/halo/live2d/chat/AiChatEndpointTest.java @@ -44,4 +44,49 @@ void accessModeTakesPrecedenceOverNullLegacyAnonymousSetting() throws Exception assertThat(config.aiChatBaseSetting().resolvedAccessMode()) .isEqualTo(AgentAccessMode.ANONYMOUS_CHAT); } + + @Test + void ignoresUiOnlyFieldsAndParsesSplitSettings() throws Exception { + var config = objectMapper.readValue(""" + { + "isAiChat": true, + "aiChatBaseSetting": { + "accessMode": "anonymous_chat_agent", + "systemMessage": "system", + "modelName": "model" + }, + "aiChatDisplaySetting": { + "chunkTimeout": 10, + "showChatMessageTimeout": 10 + }, + "aiChatSecuritySetting": { + "antiHotlinkEnabled": true, + "rateLimitRequests": 5, + "rateLimitWindowSeconds": 30 + } + } + """, AiChatEndpoint.AiChatConfig.class); + + assertThat(config.aiChatBaseSetting().resolvedAccessMode()) + .isEqualTo(AgentAccessMode.ANONYMOUS_CHAT_AGENT); + assertThat(config.securitySetting().normalizedRateLimitRequests()) + .isEqualTo(5); + assertThat(config.securitySetting().allowMissingOrigin()).isFalse(); + } + + @Test + void usesSecurityDefaultsWhenSplitSecurityIsMissing() throws Exception { + var config = objectMapper.readValue(""" + { + "isAiChat": true, + "aiChatBaseSetting": { + "accessMode": "anonymous_chat", + "systemMessage": "system", + "modelName": "model" + } + } + """, AiChatEndpoint.AiChatConfig.class); + + assertThat(config.securitySetting().normalizedRateLimitRequests()).isEqualTo(20); + } } diff --git a/src/test/java/run/halo/live2d/chat/ChatSecurityServiceTest.java b/src/test/java/run/halo/live2d/chat/ChatSecurityServiceTest.java new file mode 100644 index 0000000..8d856aa --- /dev/null +++ b/src/test/java/run/halo/live2d/chat/ChatSecurityServiceTest.java @@ -0,0 +1,111 @@ +package run.halo.live2d.chat; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import io.github.resilience4j.ratelimiter.RateLimiterRegistry; +import java.net.InetSocketAddress; +import java.net.URI; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.springframework.mock.web.reactive.function.server.MockServerRequest; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.web.server.ResponseStatusException; + +class ChatSecurityServiceTest { + + private final RateLimiterRegistry rateLimiterRegistry = RateLimiterRegistry.ofDefaults(); + private final ChatSecurityService service = new ChatSecurityService(rateLimiterRegistry); + + @Test + void rejectsCrossSiteBrowserRequests() { + var request = requestBuilder() + .header("Origin", "https://evil.example") + .header("Sec-Fetch-Site", "cross-site") + .build(); + + assertThatThrownBy(() -> service.verifyHotlink(request, security())) + .isInstanceOf(ResponseStatusException.class) + .hasMessageContaining("403"); + } + + @Test + void rejectsMissingOriginWhenConfigured() { + var request = requestBuilder().build(); + + assertThatThrownBy(() -> service.verifyHotlink(request, security())) + .isInstanceOf(ResponseStatusException.class) + .hasMessageContaining("403"); + } + + @Test + void allowsSameOriginAndConfiguredOrigins() { + var sameOrigin = requestBuilder() + .header("Origin", "https://blog.example") + .build(); + var allowedOrigin = requestBuilder() + .header("Origin", "https://app.example") + .build(); + + service.verifyHotlink(sameOrigin, security()); + service.verifyHotlink(allowedOrigin, security()); + } + + @Test + void rateLimitsAnonymousClientByIp() { + var request = requestBuilder() + .header("Origin", "https://blog.example") + .build(); + var security = new AiChatEndpoint.AiChatSecuritySetting( + true, false, List.of(), true, 1, 60); + + service.verifyRateLimit(request, security, null); + + assertThatThrownBy(() -> service.verifyRateLimit(request, security, null)) + .isInstanceOf(ResponseStatusException.class) + .hasMessageContaining("429"); + } + + @Test + void authenticatedUsersDoNotShareAnonymousLimiter() { + var request = requestBuilder().build(); + var security = new AiChatEndpoint.AiChatSecuritySetting( + true, false, List.of(), true, 1, 60); + var alice = new TestingAuthenticationToken("alice", "password"); + alice.setAuthenticated(true); + var bob = new TestingAuthenticationToken("bob", "password"); + bob.setAuthenticated(true); + + service.verifyRateLimit(request, security, alice); + service.verifyRateLimit(request, security, bob); + + assertThatThrownBy(() -> service.verifyRateLimit(request, security, alice)) + .isInstanceOf(ResponseStatusException.class) + .hasMessageContaining("429"); + } + + @Test + void disposeRemovesManagedLimiters() { + var request = requestBuilder().build(); + var security = new AiChatEndpoint.AiChatSecuritySetting( + true, false, List.of(), true, 1, 60); + + service.verifyRateLimit(request, security, null); + assertThat(rateLimiterRegistry.getAllRateLimiters()).hasSize(1); + + service.dispose(); + + assertThat(rateLimiterRegistry.getAllRateLimiters()).isEmpty(); + } + + private AiChatEndpoint.AiChatSecuritySetting security() { + return new AiChatEndpoint.AiChatSecuritySetting( + true, false, List.of("https://app.example"), true, 20, 60); + } + + private MockServerRequest.Builder requestBuilder() { + return MockServerRequest.builder() + .uri(URI.create("https://blog.example/apis/api.live2d.halo.run/v1alpha1/live2d/ai/chat-process")) + .remoteAddress(new InetSocketAddress("192.0.2.10", 12345)); + } +}