Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ public record GoogleGenerative(
@SerializedName("maxOutputTokens") Integer maxTokens,
@SerializedName("topK") Integer topK,
@SerializedName("topP") Float topP,
@SerializedName("temperature") Float temperature) implements Generative {
@SerializedName("temperature") Float temperature,
@SerializedName("location") String location) implements Generative {

@Override
public Kind _kind() {
Expand Down Expand Up @@ -63,7 +64,8 @@ public GoogleGenerative(Builder builder) {
builder.maxTokens,
builder.topK,
builder.topP,
builder.temperature);
builder.temperature,
builder.location);
}

public abstract static class Builder implements ObjectBuilder<GoogleGenerative> {
Expand All @@ -78,6 +80,7 @@ public abstract static class Builder implements ObjectBuilder<GoogleGenerative>
private Integer topK;
private Float topP;
private Float temperature;
private String location;

public Builder(String apiEndpoint, String projectId) {
this.projectId = projectId;
Expand Down Expand Up @@ -141,6 +144,12 @@ public Builder temperature(float temperature) {
return this;
}

/** Defaults to {@code us-central} on the server. */
public Builder location(String location) {
this.location = location;
return this;
}

@Override
public GoogleGenerative build() {
return new GoogleGenerative(this);
Expand Down Expand Up @@ -189,6 +198,7 @@ public static record Provider(
String projectId,
String endpointId,
String region,
String location,
List<String> stopSequences,
List<String> images,
List<String> imageProperties) implements GenerativeProvider {
Expand Down Expand Up @@ -245,6 +255,9 @@ public void appendTo(
provider.setStopSequences(WeaviateProtoBase.TextArray.newBuilder()
.addAllValues(stopSequences));
}
if (location != null) {
provider.setLocation(location);
}
req.setGoogle(provider);
}

Expand All @@ -261,6 +274,7 @@ public Provider(Builder builder) {
builder.projectId,
builder.endpointId,
builder.region,
builder.location,
builder.stopSequences,
builder.images,
builder.imageProperties);
Expand All @@ -279,6 +293,7 @@ public abstract static class Builder implements ObjectBuilder<GoogleGenerative.P
private Float presencePenalty;
private String endpointId;
private String region;
private String location;
private final List<String> stopSequences = new ArrayList<>();
private final List<String> images = new ArrayList<>();
private final List<String> imageProperties = new ArrayList<>();
Expand Down Expand Up @@ -353,6 +368,11 @@ public Builder region(String region) {
return this;
}

public Builder location(String location) {
this.location = location;
return this;
}

public Builder images(String... images) {
return images(Arrays.asList(images));
}
Expand Down
Loading
Loading