diff --git a/.stats.yml b/.stats.yml index 1f1a1736..b80d385d 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 80 openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/openai%2Fopenai-4bce8217a697c729ac98046d4caf2c9e826b54c427fb0ab4f98e549a2e0ce31c.yml openapi_spec_hash: 7996d2c34cc44fe2ce9ffe93c0ab774e -config_hash: 178ba1bfb1237bf6b94abb3408072aa7 +config_hash: 578c5bff4208d560c0c280f13324409f diff --git a/README.md b/README.md index f1d9f7ba..b39d74a4 100644 --- a/README.md +++ b/README.md @@ -292,7 +292,7 @@ OpenAIClient client = OpenAIOkHttpClient.builder() The SDK provides conveniences for streamed chat completions. A [`ChatCompletionAccumulator`](openai-java-core/src/main/kotlin/com/openai/helpers/ChatCompletionAccumulator.kt) -can record the stream of chat completion chunks in the response as they are processed and accumulate +can record the stream of chat completion chunks in the response as they are processed and accumulate a [`ChatCompletion`](openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletion.kt) object similar to that which would have been returned by the non-streaming API. @@ -340,6 +340,53 @@ client.chat() ChatCompletion chatCompletion = chatCompletionAccumulator.chatCompletion(); ``` +The SDK provides conveniences for streamed responses. A +[`ResponseAccumulator`](openai-java-core/src/main/kotlin/com/openai/helpers/ResponseAccumulator.kt) +can record the stream of response events as they are processed and accumulate a +[`Response`](openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletion.kt) +object similar to that which would have been returned by the non-streaming API. + +For a synchronous response add a +[`Stream.peek()`](https://docs.oracle.com/javase/8/docs/api/java/util/stream/Stream.html#peek-java.util.function.Consumer-) +call to the stream pipeline to accumulate each event: + +```java +import com.openai.core.http.StreamResponse; +import com.openai.helpers.ResponseAccumulator; +import com.openai.models.responses.Response; +import com.openai.models.responses.ResponseStreamEvent; + +ResponseAccumulator responseAccumulator = ResponseAccumulator.create(); + +try (StreamResponse streamResponse = + client.responses().createStreaming(createParams)) { + streamResponse.stream() + .peek(responseAccumulator::accumulate) + .flatMap(event -> event.outputTextDelta().stream()) + .forEach(textEvent -> System.out.print(textEvent.delta())); +} + +Response response = responseAccumulator.response(); +``` + +For an asynchronous response, add the `ResponseAccumulator` to the `subscribe()` call: + +```java +import com.openai.helpers.ResponseAccumulator; +import com.openai.models.responses.Response; + +ResponseAccumulator responseAccumulator = ResponseAccumulator.create(); + +client.responses() + .createStreaming(createParams) + .subscribe(event -> responseAccumulator.accumulate(event) + .outputTextDelta().ifPresent(textEvent -> System.out.print(textEvent.delta()))) + .onCompleteFuture() + .join(); + +Response response = responseAccumulator.response(); +``` + ## File uploads The SDK defines methods that accept files. diff --git a/openai-java-core/src/main/kotlin/com/openai/core/ObjectMappers.kt b/openai-java-core/src/main/kotlin/com/openai/core/ObjectMappers.kt index 7be25df3..24bec380 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/ObjectMappers.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/ObjectMappers.kt @@ -4,12 +4,16 @@ package com.openai.core import com.fasterxml.jackson.annotation.JsonInclude import com.fasterxml.jackson.core.JsonGenerator +import com.fasterxml.jackson.core.JsonParseException +import com.fasterxml.jackson.core.JsonParser +import com.fasterxml.jackson.databind.DeserializationContext import com.fasterxml.jackson.databind.DeserializationFeature import com.fasterxml.jackson.databind.MapperFeature import com.fasterxml.jackson.databind.SerializationFeature import com.fasterxml.jackson.databind.SerializerProvider import com.fasterxml.jackson.databind.cfg.CoercionAction import com.fasterxml.jackson.databind.cfg.CoercionInputShape +import com.fasterxml.jackson.databind.deser.std.StdDeserializer import com.fasterxml.jackson.databind.json.JsonMapper import com.fasterxml.jackson.databind.module.SimpleModule import com.fasterxml.jackson.databind.type.LogicalType @@ -17,13 +21,23 @@ import com.fasterxml.jackson.datatype.jdk8.Jdk8Module import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule import com.fasterxml.jackson.module.kotlin.kotlinModule import java.io.InputStream +import java.time.DateTimeException +import java.time.LocalDate +import java.time.LocalDateTime +import java.time.ZonedDateTime +import java.time.format.DateTimeFormatter +import java.time.temporal.ChronoField fun jsonMapper(): JsonMapper = JsonMapper.builder() .addModule(kotlinModule()) .addModule(Jdk8Module()) .addModule(JavaTimeModule()) - .addModule(SimpleModule().addSerializer(InputStreamJsonSerializer)) + .addModule( + SimpleModule() + .addSerializer(InputStreamSerializer) + .addDeserializer(LocalDateTime::class.java, LenientLocalDateTimeDeserializer()) + ) .withCoercionConfig(LogicalType.Boolean) { it.setCoercion(CoercionInputShape.Integer, CoercionAction.Fail) .setCoercion(CoercionInputShape.Float, CoercionAction.Fail) @@ -91,7 +105,10 @@ fun jsonMapper(): JsonMapper = .disable(MapperFeature.AUTO_DETECT_SETTERS) .build() -private object InputStreamJsonSerializer : BaseSerializer(InputStream::class) { +/** A serializer that serializes [InputStream] to bytes. */ +private object InputStreamSerializer : BaseSerializer(InputStream::class) { + + private fun readResolve(): Any = InputStreamSerializer override fun serialize( value: InputStream?, @@ -105,3 +122,46 @@ private object InputStreamJsonSerializer : BaseSerializer(InputStre } } } + +/** + * A deserializer that can deserialize [LocalDateTime] from datetimes, dates, and zoned datetimes. + */ +private class LenientLocalDateTimeDeserializer : + StdDeserializer(LocalDateTime::class.java) { + + companion object { + + private val DATE_TIME_FORMATTERS = + listOf( + DateTimeFormatter.ISO_LOCAL_DATE_TIME, + DateTimeFormatter.ISO_LOCAL_DATE, + DateTimeFormatter.ISO_ZONED_DATE_TIME, + ) + } + + override fun logicalType(): LogicalType = LogicalType.DateTime + + override fun deserialize(p: JsonParser, context: DeserializationContext?): LocalDateTime { + val exceptions = mutableListOf() + + for (formatter in DATE_TIME_FORMATTERS) { + try { + val temporal = formatter.parse(p.text) + + return when { + !temporal.isSupported(ChronoField.HOUR_OF_DAY) -> + LocalDate.from(temporal).atStartOfDay() + !temporal.isSupported(ChronoField.OFFSET_SECONDS) -> + LocalDateTime.from(temporal) + else -> ZonedDateTime.from(temporal).toLocalDateTime() + } + } catch (e: DateTimeException) { + exceptions.add(e) + } + } + + throw JsonParseException(p, "Cannot parse `LocalDateTime` from value: ${p.text}").apply { + exceptions.forEach { addSuppressed(it) } + } + } +} diff --git a/openai-java-core/src/main/kotlin/com/openai/helpers/ResponseAccumulator.kt b/openai-java-core/src/main/kotlin/com/openai/helpers/ResponseAccumulator.kt new file mode 100644 index 00000000..053ce6bc --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/helpers/ResponseAccumulator.kt @@ -0,0 +1,208 @@ +package com.openai.helpers + +import com.openai.models.responses.Response +import com.openai.models.responses.ResponseAudioDeltaEvent +import com.openai.models.responses.ResponseAudioDoneEvent +import com.openai.models.responses.ResponseAudioTranscriptDeltaEvent +import com.openai.models.responses.ResponseAudioTranscriptDoneEvent +import com.openai.models.responses.ResponseCodeInterpreterCallCodeDeltaEvent +import com.openai.models.responses.ResponseCodeInterpreterCallCodeDoneEvent +import com.openai.models.responses.ResponseCodeInterpreterCallCompletedEvent +import com.openai.models.responses.ResponseCodeInterpreterCallInProgressEvent +import com.openai.models.responses.ResponseCodeInterpreterCallInterpretingEvent +import com.openai.models.responses.ResponseCompletedEvent +import com.openai.models.responses.ResponseContentPartAddedEvent +import com.openai.models.responses.ResponseContentPartDoneEvent +import com.openai.models.responses.ResponseCreatedEvent +import com.openai.models.responses.ResponseErrorEvent +import com.openai.models.responses.ResponseFailedEvent +import com.openai.models.responses.ResponseFileSearchCallCompletedEvent +import com.openai.models.responses.ResponseFileSearchCallInProgressEvent +import com.openai.models.responses.ResponseFileSearchCallSearchingEvent +import com.openai.models.responses.ResponseFunctionCallArgumentsDeltaEvent +import com.openai.models.responses.ResponseFunctionCallArgumentsDoneEvent +import com.openai.models.responses.ResponseInProgressEvent +import com.openai.models.responses.ResponseIncompleteEvent +import com.openai.models.responses.ResponseOutputItemAddedEvent +import com.openai.models.responses.ResponseOutputItemDoneEvent +import com.openai.models.responses.ResponseRefusalDeltaEvent +import com.openai.models.responses.ResponseRefusalDoneEvent +import com.openai.models.responses.ResponseStreamEvent +import com.openai.models.responses.ResponseTextAnnotationDeltaEvent +import com.openai.models.responses.ResponseTextDeltaEvent +import com.openai.models.responses.ResponseTextDoneEvent +import com.openai.models.responses.ResponseWebSearchCallCompletedEvent +import com.openai.models.responses.ResponseWebSearchCallInProgressEvent +import com.openai.models.responses.ResponseWebSearchCallSearchingEvent + +/** + * An accumulator that constructs a [Response] from a sequence of streamed events. Pass all events + * to [accumulate] and then call [response] to get the final accumulated response. The final + * `Response` will be similar to what would have been received had the non-streaming API been used. + * + * A [ResponseAccumulator] may only be used to accumulate _one_ response. To accumulate another + * response, create another instance of `ResponseAccumulator`. + */ +class ResponseAccumulator private constructor() { + + /** + * The response accumulated from the event stream. This is set when a terminal event is + * accumulated. That single event carries all the response details. + */ + private var response: Response? = null + + companion object { + @JvmStatic fun create() = ResponseAccumulator() + } + + /** + * Gets the final accumulated response. Until the last event has been accumulated, a [Response] + * will not be available. Wait until all events have been handled by [accumulate] before calling + * this method. + * + * @throws IllegalStateException If called before the stream has been completed. + */ + fun response() = checkNotNull(response) { "Completed response is not yet received." } + + /** + * Accumulates a streamed event and uses it to construct a [Response]. When all events have been + * accumulated, the response can be retrieved by calling [response]. + * + * @return The given [event] for convenience, such as when chaining method calls. + * @throws IllegalStateException If [accumulate] is called again after the last event has been + * accumulated. A [ResponseAccumulator] can only be used to accumulate a single [Response]. + */ + fun accumulate(event: ResponseStreamEvent): ResponseStreamEvent { + check(response == null) { "Response has already been completed." } + + event.accept( + object : ResponseStreamEvent.Visitor { + // -------------------------------------------------------------------------------- + // The following events _all_ have a `Response` property. + + override fun visitCreated(created: ResponseCreatedEvent) { + // TODO: Taking not action here on the assumption that there is no need to store + // the initial `Response` (devoid of any content), as it will be replaced + // later by one of the "terminal" events. OTOH, this could be useful if the + // events stop suddenly before any further response details can be recorded. + } + + override fun visitInProgress(inProgress: ResponseInProgressEvent) { + // TODO: Taking no action here on the assumption that this is just some sort of + // "keep-alive" event that carries no new data that needs to be accumulated. + // OTOH, if the events stop suddenly, this could be used as a "partial" + // response, or an ongoing "story so far". + } + + override fun visitCompleted(completed: ResponseCompletedEvent) { + response = completed.response() + } + + override fun visitFailed(failed: ResponseFailedEvent) { + // TODO: Confirm that this is a "terminal" event and will occur _instead of_ + // `ResponseCompletedEvent`. + response = failed.response() + } + + override fun visitIncomplete(incomplete: ResponseIncompleteEvent) { + // TODO: Confirm that this is a "terminal" event and will occur _instead of_ + // `ResponseCompletedEvent`. + response = incomplete.response() + } + + // -------------------------------------------------------------------------------- + // The following events do _not_ have a `Response` property. + + override fun visitError(error: ResponseErrorEvent) {} + + override fun visitOutputItemAdded(outputItemAdded: ResponseOutputItemAddedEvent) {} + + override fun visitOutputItemDone(outputItemDone: ResponseOutputItemDoneEvent) {} + + override fun visitContentPartAdded( + contentPartAdded: ResponseContentPartAddedEvent + ) {} + + override fun visitContentPartDone(contentPartDone: ResponseContentPartDoneEvent) {} + + override fun visitOutputTextDelta(outputTextDelta: ResponseTextDeltaEvent) {} + + override fun visitOutputTextAnnotationAdded( + outputTextAnnotationAdded: ResponseTextAnnotationDeltaEvent + ) {} + + override fun visitOutputTextDone(outputTextDone: ResponseTextDoneEvent) {} + + override fun visitRefusalDelta(refusalDelta: ResponseRefusalDeltaEvent) {} + + override fun visitRefusalDone(refusalDone: ResponseRefusalDoneEvent) {} + + override fun visitFunctionCallArgumentsDelta( + functionCallArgumentsDelta: ResponseFunctionCallArgumentsDeltaEvent + ) {} + + override fun visitFunctionCallArgumentsDone( + functionCallArgumentsDone: ResponseFunctionCallArgumentsDoneEvent + ) {} + + override fun visitFileSearchCallInProgress( + fileSearchCallInProgress: ResponseFileSearchCallInProgressEvent + ) {} + + override fun visitFileSearchCallSearching( + fileSearchCallSearching: ResponseFileSearchCallSearchingEvent + ) {} + + override fun visitFileSearchCallCompleted( + fileSearchCallCompleted: ResponseFileSearchCallCompletedEvent + ) {} + + override fun visitWebSearchCallInProgress( + webSearchCallInProgress: ResponseWebSearchCallInProgressEvent + ) {} + + override fun visitWebSearchCallSearching( + webSearchCallSearching: ResponseWebSearchCallSearchingEvent + ) {} + + override fun visitWebSearchCallCompleted( + webSearchCallCompleted: ResponseWebSearchCallCompletedEvent + ) {} + + override fun visitAudioDelta(audioDelta: ResponseAudioDeltaEvent) {} + + override fun visitAudioDone(audioDone: ResponseAudioDoneEvent) {} + + override fun visitAudioTranscriptDelta( + audioTranscriptDelta: ResponseAudioTranscriptDeltaEvent + ) {} + + override fun visitAudioTranscriptDone( + audioTranscriptDone: ResponseAudioTranscriptDoneEvent + ) {} + + override fun visitCodeInterpreterCallCodeDelta( + codeInterpreterCallCodeDelta: ResponseCodeInterpreterCallCodeDeltaEvent + ) {} + + override fun visitCodeInterpreterCallCodeDone( + codeInterpreterCallCodeDone: ResponseCodeInterpreterCallCodeDoneEvent + ) {} + + override fun visitCodeInterpreterCallInProgress( + codeInterpreterCallInProgress: ResponseCodeInterpreterCallInProgressEvent + ) {} + + override fun visitCodeInterpreterCallInterpreting( + codeInterpreterCallInterpreting: ResponseCodeInterpreterCallInterpretingEvent + ) {} + + override fun visitCodeInterpreterCallCompleted( + codeInterpreterCallCompleted: ResponseCodeInterpreterCallCompletedEvent + ) {} + } + ) + + return event + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/core/ObjectMappersTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/ObjectMappersTest.kt index 17fe45c3..2dbf1a4c 100644 --- a/openai-java-core/src/test/kotlin/com/openai/core/ObjectMappersTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/core/ObjectMappersTest.kt @@ -2,10 +2,15 @@ package com.openai.core import com.fasterxml.jackson.annotation.JsonProperty import com.fasterxml.jackson.databind.exc.MismatchedInputException +import com.fasterxml.jackson.module.kotlin.readValue +import java.time.LocalDateTime import kotlin.reflect.KClass import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.catchThrowable import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertDoesNotThrow +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.EnumSource import org.junitpioneer.jupiter.cartesian.CartesianTest internal class ObjectMappersTest { @@ -78,4 +83,20 @@ internal class ObjectMappersTest { assertThat(e).isInstanceOf(MismatchedInputException::class.java) } } + + enum class LenientLocalDateTimeTestCase(val string: String) { + DATE("1998-04-21"), + DATE_TIME("1998-04-21T04:00:00"), + ZONED_DATE_TIME_1("1998-04-21T04:00:00+03:00"), + ZONED_DATE_TIME_2("1998-04-21T04:00:00Z"), + } + + @ParameterizedTest + @EnumSource + fun readLocalDateTime_lenient(testCase: LenientLocalDateTimeTestCase) { + val jsonMapper = jsonMapper() + val json = jsonMapper.writeValueAsString(testCase.string) + + assertDoesNotThrow { jsonMapper().readValue(json) } + } } diff --git a/openai-java-core/src/test/kotlin/com/openai/helpers/ResponseAccumulatorTest.kt b/openai-java-core/src/test/kotlin/com/openai/helpers/ResponseAccumulatorTest.kt new file mode 100644 index 00000000..b6f402f6 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/helpers/ResponseAccumulatorTest.kt @@ -0,0 +1,141 @@ +package com.openai.helpers + +import com.openai.core.JsonNull +import com.openai.models.ResponsesModel +import com.openai.models.responses.Response +import com.openai.models.responses.ResponseCompletedEvent +import com.openai.models.responses.ResponseCreatedEvent +import com.openai.models.responses.ResponseFailedEvent +import com.openai.models.responses.ResponseInProgressEvent +import com.openai.models.responses.ResponseIncompleteEvent +import com.openai.models.responses.ResponseOutputItem +import com.openai.models.responses.ResponseOutputMessage +import com.openai.models.responses.ResponseOutputText +import com.openai.models.responses.ResponseStreamEvent +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatNoException +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.jupiter.api.Test + +internal class ResponseAccumulatorTest { + + @Test + fun responseBeforeAccumulation() { + val accumulator = ResponseAccumulator.create() + + assertThatThrownBy { accumulator.response() } + .isExactlyInstanceOf(IllegalStateException::class.java) + .hasMessage("Completed response is not yet received.") + } + + @Test + fun responseAfterAccumulation() { + val accumulator = ResponseAccumulator.create() + + accumulator.accumulate(ResponseStreamEvent.ofCompleted(responseCompletedEvent())) + + assertThatNoException().isThrownBy { accumulator.response() } + assertThat(accumulator.response().id()).isEqualTo("response-id") + } + + @Test + fun accumulateAfterCompleted() { + val accumulator = ResponseAccumulator.create() + + accumulator.accumulate(ResponseStreamEvent.ofCompleted(responseCompletedEvent())) + + assertThatThrownBy { + accumulator.accumulate(ResponseStreamEvent.ofCompleted(responseCompletedEvent())) + } + .isExactlyInstanceOf(IllegalStateException::class.java) + .hasMessage("Response has already been completed.") + } + + @Test + fun accumulateUntilCompleted() { + val accumulator = ResponseAccumulator.create() + + accumulator.accumulate(ResponseStreamEvent.ofCreated(responseCreatedEvent())) + accumulator.accumulate(ResponseStreamEvent.ofInProgress(responseInProgressEvent())) + accumulator.accumulate(ResponseStreamEvent.ofInProgress(responseInProgressEvent())) + accumulator.accumulate(ResponseStreamEvent.ofInProgress(responseInProgressEvent())) + accumulator.accumulate(ResponseStreamEvent.ofCompleted(responseCompletedEvent())) + + val response = accumulator.response() + + assertThat(response.id()).isEqualTo("response-id") + } + + @Test + fun accumulateUntilIncomplete() { + val accumulator = ResponseAccumulator.create() + + accumulator.accumulate(ResponseStreamEvent.ofCreated(responseCreatedEvent())) + accumulator.accumulate(ResponseStreamEvent.ofInProgress(responseInProgressEvent())) + accumulator.accumulate(ResponseStreamEvent.ofInProgress(responseInProgressEvent())) + accumulator.accumulate(ResponseStreamEvent.ofInProgress(responseInProgressEvent())) + accumulator.accumulate(ResponseStreamEvent.ofIncomplete(responseIncompleteEvent())) + + val response = accumulator.response() + + assertThat(response.id()).isEqualTo("response-id") + } + + @Test + fun accumulateUntilFailed() { + val accumulator = ResponseAccumulator.create() + + accumulator.accumulate(ResponseStreamEvent.ofCreated(responseCreatedEvent())) + accumulator.accumulate(ResponseStreamEvent.ofInProgress(responseInProgressEvent())) + accumulator.accumulate(ResponseStreamEvent.ofInProgress(responseInProgressEvent())) + accumulator.accumulate(ResponseStreamEvent.ofInProgress(responseInProgressEvent())) + accumulator.accumulate(ResponseStreamEvent.ofFailed(responseFailedEvent())) + + val response = accumulator.response() + + assertThat(response.id()).isEqualTo("response-id") + } + + private fun responseCreatedEvent() = ResponseCreatedEvent.builder().response(response()).build() + + private fun responseInProgressEvent() = + ResponseInProgressEvent.builder().response(response()).build() + + private fun responseCompletedEvent() = + ResponseCompletedEvent.builder().response(response()).build() + + private fun responseFailedEvent() = ResponseFailedEvent.builder().response(response()).build() + + private fun responseIncompleteEvent() = + ResponseIncompleteEvent.builder().response(response()).build() + + private fun response() = + Response.builder() + .id("response-id") + .createdAt(System.currentTimeMillis() / 1_000.0) + .error(null) + .incompleteDetails(null) + .instructions(null) + .metadata(null) + .model(ResponsesModel.UnionMember2.O1_PRO) + .addOutput(responseOutputItemOfMessage()) + .parallelToolCalls(false) + .temperature(null) + .toolChoice(JsonNull.of()) + .tools(listOf()) + .topP(null) + .build() + + private fun responseOutputItemOfMessage() = + ResponseOutputItem.ofMessage(responseOutputMessage()) + + private fun responseOutputMessage() = + ResponseOutputMessage.builder() + .id("message-id") + .addContent(ResponseOutputMessage.Content.ofOutputText(responseOutputText())) + .status(ResponseOutputMessage.Status.COMPLETED) + .build() + + private fun responseOutputText() = + ResponseOutputText.builder().text("Hello World").annotations(listOf()).build() +}