diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java index 55c622ff7..c6fe6f910 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java @@ -19,6 +19,7 @@ import io.weaviate.client6.v1.api.collections.generative.AzureOpenAiGenerative; import io.weaviate.client6.v1.api.collections.generative.CohereGenerative; import io.weaviate.client6.v1.api.collections.generative.DatabricksGenerative; +import io.weaviate.client6.v1.api.collections.generative.DeepseekGenerative; import io.weaviate.client6.v1.api.collections.generative.DummyGenerative; import io.weaviate.client6.v1.api.collections.generative.FriendliaiGenerative; import io.weaviate.client6.v1.api.collections.generative.GoogleGenerative; @@ -38,6 +39,7 @@ public enum Kind implements JsonEnum { ANTHROPIC("generative-anthropic"), COHERE("generative-cohere"), DATABRICKS("generative-databricks"), + DEEPSEEK("generative-deepseek"), FRIENDLIAI("generative-friendliai"), GOOGLE("generative-google"), MISTRAL("generative-mistral"), @@ -171,6 +173,22 @@ public static Generative databricks(String endpoint, return DatabricksGenerative.of(endpoint, fn); } + /** + * Configure a default {@code generative-deepseek} module. + */ + public static Generative deepseek() { + return DeepseekGenerative.of(); + } + + /** + * Configure a {@code generative-deepseek} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative deepseek(Function> fn) { + return DeepseekGenerative.of(fn); + } + /** Configure a default {@code generative-frienliai} module. */ public static Generative frienliai() { return FriendliaiGenerative.of(); @@ -384,6 +402,21 @@ default DatabricksGenerative asDatabricks() { return _as(Generative.Kind.DATABRICKS); } + /** Is this a {@code generative-deepseek} provider? */ + default boolean isDeepseek() { + return _is(Generative.Kind.DEEPSEEK); + } + + /** + * Get as {@link DeepseekGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-deepseek}. + */ + default DeepseekGenerative asDeepseek() { + return _as(Generative.Kind.DEEPSEEK); + } + /** Is this a {@code generative-friendliai} provider? */ default boolean isFriendliai() { return _is(Generative.Kind.FRIENDLIAI); @@ -520,6 +553,7 @@ private final void init(Gson gson) { addAdapter(gson, Generative.Kind.AWS, AwsGenerative.class); addAdapter(gson, Generative.Kind.COHERE, CohereGenerative.class); addAdapter(gson, Generative.Kind.DATABRICKS, DatabricksGenerative.class); + addAdapter(gson, Generative.Kind.DEEPSEEK, DeepseekGenerative.class); addAdapter(gson, Generative.Kind.GOOGLE, GoogleGenerative.class); addAdapter(gson, Generative.Kind.FRIENDLIAI, FriendliaiGenerative.class); addAdapter(gson, Generative.Kind.MISTRAL, MistralGenerative.class); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeProvider.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeProvider.java index 61085180f..e13321281 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeProvider.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeProvider.java @@ -8,6 +8,7 @@ import io.weaviate.client6.v1.api.collections.generative.AzureOpenAiGenerative; import io.weaviate.client6.v1.api.collections.generative.CohereGenerative; import io.weaviate.client6.v1.api.collections.generative.DatabricksGenerative; +import io.weaviate.client6.v1.api.collections.generative.DeepseekGenerative; import io.weaviate.client6.v1.api.collections.generative.FriendliaiGenerative; import io.weaviate.client6.v1.api.collections.generative.GoogleGenerative; import io.weaviate.client6.v1.api.collections.generative.MistralGenerative; @@ -89,6 +90,16 @@ public static GenerativeProvider databricks( return DatabricksGenerative.Provider.of(fn); } + /** + * Configure {@code generative-deepseek} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static GenerativeProvider deepseek( + Function> fn) { + return DeepseekGenerative.Provider.of(fn); + } + /** * Configure {@code generative-friendliai} as a dynamic provider. * diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DeepseekGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DeepseekGenerative.java new file mode 100644 index 000000000..55e3cb095 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DeepseekGenerative.java @@ -0,0 +1,242 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.GenerativeProvider; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record DeepseekGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("model") String model, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("temperature") Float temperature, + @SerializedName("frequencyPenalty") Float frequencyPenalty, + @SerializedName("presencePenalty") Float presencePenalty, + @SerializedName("topP") Float topP, + @SerializedName("stop") List stopSequences) implements Generative { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.DEEPSEEK; + } + + @Override + public Object _self() { + return this; + } + + public static DeepseekGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static DeepseekGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public DeepseekGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.maxTokens, + builder.temperature, + builder.frequencyPenalty, + builder.presencePenalty, + builder.topP, + builder.stopSequences); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Float temperature; + private Integer maxTokens; + private Float frequencyPenalty; + private Float presencePenalty; + private Float topP; + private List stopSequences = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + public Builder stopSequences(String... values) { + return stopSequences(Arrays.asList(values)); + } + + public Builder stopSequences(List values) { + this.stopSequences.addAll(values); + return this; + } + + @Override + public DeepseekGenerative build() { + return new DeepseekGenerative(this); + } + } + + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { + } + + public static record Provider( + String baseUrl, + String model, + Integer maxTokens, + Float temperature, + Float frequencyPenalty, + Float presencePenalty, + Float topP, + List stopSequences) implements GenerativeProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeDeepseek.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (topP != null) { + provider.setTopP(topP); + } + if (frequencyPenalty != null) { + provider.setFrequencyPenalty(frequencyPenalty); + } + if (presencePenalty != null) { + provider.setPresencePenalty(presencePenalty); + } + if (stopSequences != null) { + provider.setStop(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(stopSequences)); + } + req.setDeepseek(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.maxTokens, + builder.temperature, + builder.frequencyPenalty, + builder.presencePenalty, + builder.topP, + builder.stopSequences); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Float temperature; + private Integer maxTokens; + private Float frequencyPenalty; + private Float presencePenalty; + private Float topP; + private List stopSequences = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + public Builder stop(String... values) { + return stop(Arrays.asList(values)); + } + + public Builder stop(List values) { + this.stopSequences.addAll(values); + return this; + } + + @Override + public DeepseekGenerative.Provider build() { + return new DeepseekGenerative.Provider(this); + } + } + } +}