From 57b2cd345c510731b29296f08cfd110d5d6f0231 Mon Sep 17 00:00:00 2001 From: Josh Long <54473+joshlong@users.noreply.github.com> Date: Mon, 5 May 2025 19:42:49 -0700 Subject: [PATCH] wip Signed-off-by: Josh Long <54473+joshlong@users.noreply.github.com> --- spring-ai-model/pom.xml | 5 ++ .../memory/InMemoryChatMemoryRepository.java | 5 +- .../chat/memory/MessageWindowChatMemory.java | 3 +- ...EncryptingChatMemoryBeanPostProcessor.java | 32 +++++++ .../EncryptingChatMemoryRepository.java | 84 ++++++++++++++++++ .../ai/chat/messages/AssistantMessage.java | 87 ++++++++++++++++++- 6 files changed, 209 insertions(+), 7 deletions(-) create mode 100644 spring-ai-model/src/main/java/org/springframework/ai/chat/memory/encryption/EncryptingChatMemoryBeanPostProcessor.java create mode 100644 spring-ai-model/src/main/java/org/springframework/ai/chat/memory/encryption/EncryptingChatMemoryRepository.java diff --git a/spring-ai-model/pom.xml b/spring-ai-model/pom.xml index 167d1273502..0ae3c7f3dc1 100644 --- a/spring-ai-model/pom.xml +++ b/spring-ai-model/pom.xml @@ -64,6 +64,11 @@ true + + org.springframework.security + spring-security-config + true + org.springframework diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java index 290ccfb4174..50f33966f18 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java @@ -28,11 +28,12 @@ * An in-memory implementation of {@link ChatMemoryRepository}. * * @author Thomas Vitale + * @author Josh Long * @since 1.0.0 */ -public final class InMemoryChatMemoryRepository implements ChatMemoryRepository { +public class InMemoryChatMemoryRepository implements ChatMemoryRepository { - Map> chatMemoryStore = new ConcurrentHashMap<>(); + private final Map> chatMemoryStore = new ConcurrentHashMap<>(); @Override public List findConversationIds() { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java index 8b289d1b276..563c61bc090 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java @@ -36,9 +36,10 @@ * {@link SystemMessage} messages are preserved while evicting other types of messages. * * @author Thomas Vitale + * @author Josh Long * @since 1.0.0 */ -public final class MessageWindowChatMemory implements ChatMemory { +public class MessageWindowChatMemory implements ChatMemory { private static final int DEFAULT_MAX_MESSAGES = 200; diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/encryption/EncryptingChatMemoryBeanPostProcessor.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/encryption/EncryptingChatMemoryBeanPostProcessor.java new file mode 100644 index 00000000000..c9266a4ba26 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/encryption/EncryptingChatMemoryBeanPostProcessor.java @@ -0,0 +1,32 @@ +package org.springframework.ai.chat.memory.encryption; + +import org.jetbrains.annotations.NotNull; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.config.BeanPostProcessor; +import org.springframework.security.crypto.encrypt.TextEncryptor; + +/** + * Uses a configured {@link TextEncryptor text encryptor} to encrypt values before writes, + * and decode those values from the read operations. + * + * @author Josh Long + */ +public class EncryptingChatMemoryBeanPostProcessor implements BeanPostProcessor { + + private final TextEncryptor encryptor; + + public EncryptingChatMemoryBeanPostProcessor(TextEncryptor encryptor) { + this.encryptor = encryptor; + } + + @Override + public Object postProcessAfterInitialization(@NotNull Object bean, @NotNull String beanName) throws BeansException { + + if (bean instanceof ChatMemoryRepository cmr && !(cmr instanceof EncryptingChatMemoryRepository)) { + return new EncryptingChatMemoryRepository(cmr, encryptor); + } + return BeanPostProcessor.super.postProcessAfterInitialization(bean, beanName); + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/encryption/EncryptingChatMemoryRepository.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/encryption/EncryptingChatMemoryRepository.java new file mode 100644 index 00000000000..f8f76584aaf --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/encryption/EncryptingChatMemoryRepository.java @@ -0,0 +1,84 @@ +package org.springframework.ai.chat.memory.encryption; + +import org.jetbrains.annotations.NotNull; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.security.crypto.encrypt.TextEncryptor; + +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * + * Wraps {@link ChatMemoryRepository a ChatMemoryRepository}, encrypting and decrypting + * reads and writes respectively using a Spring Security {@link TextEncryptor text + * encryptor}. + * + * @author Josh Long + */ +public class EncryptingChatMemoryRepository implements ChatMemoryRepository { + + private final ChatMemoryRepository target; + + private final TextEncryptor textEncryptor; + + public EncryptingChatMemoryRepository(ChatMemoryRepository target, TextEncryptor textEncryptor) { + this.target = target; + this.textEncryptor = textEncryptor; + } + + private Message transform(Message message, Function function) { + + var transformedText = function.apply(message.getText()); + + // todo is there a case to be made that we should seal the message hierarchy? + if (message instanceof SystemMessage systemMessage) { + return systemMessage.mutate().text(transformedText).build(); + } + + if (message instanceof UserMessage userMessage) { + return userMessage.mutate().text(transformedText).build(); + } + + if (message instanceof AssistantMessage assistantMessage) { + return assistantMessage.mutate().text(transformedText).build(); + } + + return message; + } + + private Message decrypt(Message message) { + return this.transform(message, this.textEncryptor::decrypt); + } + + private Message encrypt(Message message) { + return this.transform(message, this.textEncryptor::encrypt); + } + + @NotNull + @Override + public List findConversationIds() { + return this.target.findConversationIds(); + } + + @NotNull + @Override + public List findByConversationId(@NotNull String conversationId) { + return this.target.findByConversationId(conversationId).stream().map(this::decrypt).toList(); + } + + @Override + public void saveAll(@NotNull String conversationId, List messages) { + this.target.saveAll(conversationId, messages.stream().map(this::encrypt).collect(Collectors.toList())); + } + + @Override + public void deleteByConversationId(@NotNull String conversationId) { + this.target.deleteByConversationId(conversationId); + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java index b092de2d6da..703ed2c89bd 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java @@ -16,14 +16,15 @@ package org.springframework.ai.chat.messages; -import java.util.List; -import java.util.Map; -import java.util.Objects; - import org.springframework.ai.content.Media; import org.springframework.ai.content.MediaContent; +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +import java.util.*; /** * Lets the generative know the content was generated as a response to the user. This role @@ -33,6 +34,7 @@ * * @author Mark Pollack * @author Christian Tzolov + * @author Josh Long * @since 1.0.0 */ public class AssistantMessage extends AbstractMessage implements MediaContent { @@ -104,4 +106,81 @@ public record ToolCall(String id, String type, String name, String arguments) { } + // + + public AssistantMessage copy() { + return new AssistantMessage.Builder().text(getText()) + .media(List.copyOf(getMedia())) + .metadata(Map.copyOf(getMetadata())) + .build(); + } + + public AssistantMessage.Builder mutate() { + return new AssistantMessage.Builder().text(getText()) + .media(List.copyOf(getMedia())) + .metadata(Map.copyOf(getMetadata())); + } + + public static AssistantMessage.Builder builder() { + return new AssistantMessage.Builder(); + } + + public static class Builder { + + @Nullable + private String textContent; + + @Nullable + private Resource resource; + + private List media = new ArrayList<>(); + + private List toolCalls = new ArrayList<>(); + + private Map metadata = new HashMap<>(); + + public AssistantMessage.Builder text(String textContent) { + this.textContent = textContent; + return this; + } + + public AssistantMessage.Builder toolCalls(List toolCalls) { + this.toolCalls = toolCalls; + return this; + } + + public AssistantMessage.Builder text(Resource resource) { + this.resource = resource; + return this; + } + + public AssistantMessage.Builder media(List media) { + this.media = media; + return this; + } + + public AssistantMessage.Builder media(@Nullable Media... media) { + if (media != null) { + this.media = Arrays.asList(media); + } + return this; + } + + public AssistantMessage.Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public AssistantMessage build() { + if (StringUtils.hasText(textContent) && resource != null) { + throw new IllegalArgumentException("textContent and resource cannot be set at the same time"); + } + else if (resource != null) { + this.textContent = MessageUtils.readResource(resource); + } + return new AssistantMessage(this.textContent, this.metadata, this.toolCalls, this.media); + } + + } + }