Skip to content

[WIP] PR for transparent encryption of values in a durable ChatMemoryRepository https://github.com/spring-projects/spring-ai/issues/2975 #3008

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions spring-ai-model/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@
<optional>true</optional>
</dependency>

<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-config</artifactId>
<optional>true</optional>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
* An in-memory implementation of {@link ChatMemoryRepository}.
*
* @author Thomas Vitale
* @author Josh Long
* @since 1.0.0
*/
public final class InMemoryChatMemoryRepository implements ChatMemoryRepository {
public class InMemoryChatMemoryRepository implements ChatMemoryRepository {

Map<String, List<Message>> chatMemoryStore = new ConcurrentHashMap<>();
private final Map<String, List<Message>> chatMemoryStore = new ConcurrentHashMap<>();

@Override
public List<String> findConversationIds() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@
* {@link SystemMessage} messages are preserved while evicting other types of messages.
*
* @author Thomas Vitale
* @author Josh Long
* @since 1.0.0
*/
public final class MessageWindowChatMemory implements ChatMemory {
public class MessageWindowChatMemory implements ChatMemory {

private static final int DEFAULT_MAX_MESSAGES = 200;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package org.springframework.ai.chat.memory.encryption;

import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.security.crypto.encrypt.TextEncryptor;

/**
* Uses a configured {@link TextEncryptor text encryptor} to encrypt values before writes,
* and decode those values from the read operations.
*
* @author Josh Long
*/
public class EncryptingChatMemoryBeanPostProcessor implements BeanPostProcessor {

private final TextEncryptor encryptor;

public EncryptingChatMemoryBeanPostProcessor(TextEncryptor encryptor) {
this.encryptor = encryptor;
}

@Override
public Object postProcessAfterInitialization(@NotNull Object bean, @NotNull String beanName) throws BeansException {

if (bean instanceof ChatMemoryRepository cmr && !(cmr instanceof EncryptingChatMemoryRepository)) {
return new EncryptingChatMemoryRepository(cmr, encryptor);
}
return BeanPostProcessor.super.postProcessAfterInitialization(bean, beanName);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package org.springframework.ai.chat.memory.encryption;

import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.security.crypto.encrypt.TextEncryptor;

import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
*
* Wraps {@link ChatMemoryRepository a ChatMemoryRepository}, encrypting and decrypting
* reads and writes respectively using a Spring Security {@link TextEncryptor text
* encryptor}.
*
* @author Josh Long
*/
public class EncryptingChatMemoryRepository implements ChatMemoryRepository {

private final ChatMemoryRepository target;

private final TextEncryptor textEncryptor;

public EncryptingChatMemoryRepository(ChatMemoryRepository target, TextEncryptor textEncryptor) {
this.target = target;
this.textEncryptor = textEncryptor;
}

private Message transform(Message message, Function<String, String> function) {

var transformedText = function.apply(message.getText());

// todo is there a case to be made that we should seal the message hierarchy?
if (message instanceof SystemMessage systemMessage) {
return systemMessage.mutate().text(transformedText).build();
}

if (message instanceof UserMessage userMessage) {
return userMessage.mutate().text(transformedText).build();
}

if (message instanceof AssistantMessage assistantMessage) {
return assistantMessage.mutate().text(transformedText).build();
}

return message;
}

private Message decrypt(Message message) {
return this.transform(message, this.textEncryptor::decrypt);
}

private Message encrypt(Message message) {
return this.transform(message, this.textEncryptor::encrypt);
}

@NotNull
@Override
public List<String> findConversationIds() {
return this.target.findConversationIds();
}

@NotNull
@Override
public List<Message> findByConversationId(@NotNull String conversationId) {
return this.target.findByConversationId(conversationId).stream().map(this::decrypt).toList();
}

@Override
public void saveAll(@NotNull String conversationId, List<Message> messages) {
this.target.saveAll(conversationId, messages.stream().map(this::encrypt).collect(Collectors.toList()));
}

@Override
public void deleteByConversationId(@NotNull String conversationId) {
this.target.deleteByConversationId(conversationId);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

package org.springframework.ai.chat.messages;

import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.springframework.ai.content.Media;
import org.springframework.ai.content.MediaContent;
import org.springframework.core.io.Resource;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.*;

/**
* Lets the generative know the content was generated as a response to the user. This role
Expand All @@ -33,6 +34,7 @@
*
* @author Mark Pollack
* @author Christian Tzolov
* @author Josh Long
* @since 1.0.0
*/
public class AssistantMessage extends AbstractMessage implements MediaContent {
Expand Down Expand Up @@ -104,4 +106,81 @@ public record ToolCall(String id, String type, String name, String arguments) {

}

//

public AssistantMessage copy() {
return new AssistantMessage.Builder().text(getText())
.media(List.copyOf(getMedia()))
.metadata(Map.copyOf(getMetadata()))
.build();
}

public AssistantMessage.Builder mutate() {
return new AssistantMessage.Builder().text(getText())
.media(List.copyOf(getMedia()))
.metadata(Map.copyOf(getMetadata()));
}

public static AssistantMessage.Builder builder() {
return new AssistantMessage.Builder();
}

public static class Builder {

@Nullable
private String textContent;

@Nullable
private Resource resource;

private List<Media> media = new ArrayList<>();

private List<ToolCall> toolCalls = new ArrayList<>();

private Map<String, Object> metadata = new HashMap<>();

public AssistantMessage.Builder text(String textContent) {
this.textContent = textContent;
return this;
}

public AssistantMessage.Builder toolCalls(List<ToolCall> toolCalls) {
this.toolCalls = toolCalls;
return this;
}

public AssistantMessage.Builder text(Resource resource) {
this.resource = resource;
return this;
}

public AssistantMessage.Builder media(List<Media> media) {
this.media = media;
return this;
}

public AssistantMessage.Builder media(@Nullable Media... media) {
if (media != null) {
this.media = Arrays.asList(media);
}
return this;
}

public AssistantMessage.Builder metadata(Map<String, Object> metadata) {
this.metadata = metadata;
return this;
}

public AssistantMessage build() {
if (StringUtils.hasText(textContent) && resource != null) {
throw new IllegalArgumentException("textContent and resource cannot be set at the same time");
}
else if (resource != null) {
this.textContent = MessageUtils.readResource(resource);
}
return new AssistantMessage(this.textContent, this.metadata, this.toolCalls, this.media);
}

}

}