diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d9692a8..268a5e2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -23,7 +23,7 @@ jobs: run: | python -m venv venv source venv/bin/activate - pip install componentize-py==0.22.0 http-router==4.1.2 build==1.4.2 + pip install componentize-py==0.22.0 http-router==4.1.2 build==1.4.2 mypy==1.13 python -m build pip install dist/spin_sdk-4.0.0-py3-none-any.whl bash run_tests.sh diff --git a/examples/external-lib-example/app.py b/examples/external-lib-example/app.py index 8bfcbf9..3808d64 100644 --- a/examples/external-lib-example/app.py +++ b/examples/external-lib-example/app.py @@ -1,3 +1,4 @@ +from typing import cast from spin_sdk import http from spin_sdk.http import Request, Response import re @@ -47,7 +48,7 @@ async def handle_request(self, request: Request) -> Response: uri = urlparse(request.uri) try: handler = router(uri.path, request.method) - return handler.target(uri, request) + return cast(Response, handler.target(uri, request)) except exceptions.NotFoundError: return Response(404, {}, None) diff --git a/examples/redis-trigger/app.py b/examples/redis-trigger/app.py index 30bafbd..162fc4b 100644 --- a/examples/redis-trigger/app.py +++ b/examples/redis-trigger/app.py @@ -1,5 +1,5 @@ from spin_sdk.wit import exports class SpinRedisInboundRedis300(exports.SpinRedisInboundRedis300): - async def handle_message(self, message: bytes): + async def handle_message(self, message: bytes) -> None: print(message) diff --git a/examples/spin-kv/app.py b/examples/spin-kv/app.py index af0b1f6..ba0d974 100644 --- a/examples/spin-kv/app.py +++ b/examples/spin-kv/app.py @@ -1,3 +1,7 @@ +from typing import TypeVar, Tuple, List +from componentize_py_types import Result, Err +from componentize_py_async_support.streams import StreamReader +from componentize_py_async_support.futures import FutureReader from spin_sdk import http, key_value from spin_sdk.http import Request, Response from spin_sdk.key_value import Store @@ -9,7 +13,7 @@ async def handle_request(self, request: Request) -> Response: print(await get_keys(a)) print(await a.exists("test")) print(await a.get("test")) - print(await a.delete("test")) + await a.delete("test") print(await get_keys(a)) return Response( @@ -18,14 +22,10 @@ async def handle_request(self, request: Request) -> Response: bytes("Hello from Python!", "utf-8") ) -async def get_keys(Store) -> list[str]: - stream, future = await Store.get_keys() - keys = [] - - while True: - batch = await stream.read(max_count=100) - if not batch: - break - keys.extend(batch) - - return keys \ No newline at end of file +async def get_keys(store: Store) -> list[str]: + stream, future = await store.get_keys() + with stream, future: + keys = [] + while not stream.writer_dropped: + keys += await stream.read(max_count=100) + return keys diff --git a/examples/spin-postgres/app.py b/examples/spin-postgres/app.py index ff0bae9..bd70537 100644 --- a/examples/spin-postgres/app.py +++ b/examples/spin-postgres/app.py @@ -1,14 +1,15 @@ from spin_sdk import http, postgres from spin_sdk.http import Request, Response +from spin_sdk.postgres import RowSet, DbValue -def format_value(db_value) -> str: +def format_value(db_value: DbValue) -> str: if hasattr(db_value, "value"): return str(db_value.value) return "NULL" -def format_rowset(rowset) -> str: +def format_rowset(rowset: RowSet) -> str: lines = [] col_names = [col.name for col in rowset.columns] lines.append(" | ".join(col_names)) diff --git a/run_tests.sh b/run_tests.sh index 90dd623..b616d42 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -2,6 +2,10 @@ source venv/bin/activate +# First, install any example-specific dependencies (common dependencies such as +# `componentize-py`, `spin-sdk`, and `mypy` are assumed to have been installed +# in the virtual environment). + if [ ! -d examples/matrix-math/numpy ] then (cd examples/matrix-math \ @@ -9,14 +13,32 @@ then && tar xf numpy-wasi.tar.gz) fi +# Next, run MyPy on all the examples + +for example in examples/* +do + echo "linting $example" + if [ $example = "examples/matrix-math" ] + then + # NumPy fails linting as of this writing, so we skip it + extra_option="--follow-imports silent" + else + unset extra_option + fi + export MYPYPATH=$(pwd)/src + (cd $example && mypy --strict $extra_option -m app) || exit 1 +done + +# Next, build all the examples + for example in examples/* do echo "building $example" (cd $example && spin build) || exit 1 done +# Finally, run some of the examples and test that they behave as expected -# run trivial examples for example in examples/hello examples/external-lib-example examples/spin-kv examples/spin-variables do pushd $example diff --git a/src/spin_sdk/http/__init__.py b/src/spin_sdk/http/__init__.py index b47b402..943cbdf 100644 --- a/src/spin_sdk/http/__init__.py +++ b/src/spin_sdk/http/__init__.py @@ -101,7 +101,7 @@ async def handle(self, request: WasiRequest) -> WasiResponse: simple_response.headers['content-length'] = str(content_length) tx, rx = wit.byte_stream() - componentize_py_async_support.spawn(copy(simple_response.body, tx)) + componentize_py_async_support.spawn(_copy(simple_response.body, tx)) response = WasiResponse.new(Fields.from_list(list(map( lambda pair: (pair[0], bytes(pair[1], "utf-8")), simple_response.headers.items() @@ -159,7 +159,7 @@ async def send(request: Request) -> Response: content_length = len(request.body) if request.body is not None else 0 # Make a copy rather than mutate in place, since the caller might not # expect us to mutate it: - headers_dict = headers_dict.copy() + headers_dict = dict(headers_dict) headers_dict['content-length'] = str(content_length) headers = list(map( @@ -168,12 +168,12 @@ async def send(request: Request) -> Response: )) tx, rx = wit.byte_stream() - componentize_py_async_support.spawn(copy(request.body, tx)) + componentize_py_async_support.spawn(_copy(request.body, tx)) outgoing_request = WasiRequest.new(Fields.from_list(headers), rx, _trailers_future(), None)[0] outgoing_request.set_method(method) outgoing_request.set_scheme(scheme) if url_parsed.netloc == '': - if scheme == "http": + if isinstance(scheme, Scheme_Http): authority = ":80" else: authority = ":443" @@ -228,7 +228,7 @@ def strip_forbidden_headers(headers:MutableMapping[str, str]) -> MutableMapping[ pass return headers -async def copy(bytes:bytes, tx:ByteStreamWriter): +async def _copy(bytes: bytes | None, tx: ByteStreamWriter) -> None: with tx: if bytes is not None: await tx.write_all(bytes) diff --git a/src/spin_sdk/key_value.py b/src/spin_sdk/key_value.py index 08494b9..569c4c6 100644 --- a/src/spin_sdk/key_value.py +++ b/src/spin_sdk/key_value.py @@ -1,6 +1,8 @@ """Module for accessing Spin key-value stores""" -from spin_sdk.wit.imports.spin_key_value_key_value_3_0_0 import Store +from spin_sdk.wit.imports import spin_key_value_key_value_3_0_0 as kv + +Store = kv.Store async def open(name: str) -> Store: """ diff --git a/src/spin_sdk/llm.py b/src/spin_sdk/llm.py index 39ad14c..f24269a 100644 --- a/src/spin_sdk/llm.py +++ b/src/spin_sdk/llm.py @@ -1,9 +1,10 @@ """Module for working with the Spin large language model API""" from dataclasses import dataclass -from typing import Optional, Sequence +from typing import Optional, List from spin_sdk.wit.imports import fermyon_spin_llm_2_0_0 as spin_llm + @dataclass class InferencingParams: max_tokens: int = 100 @@ -14,7 +15,7 @@ class InferencingParams: top_p: float = 0.9 -def generate_embeddings(model: str, text: Sequence[str]) -> spin_llm.EmbeddingsResult: +def generate_embeddings(model: str, text: List[str]) -> spin_llm.EmbeddingsResult: """ A `componentize_py_types.Err(spin_sdk.wit.imports.fermyon_spin_llm_2_0_0.Error_ModelNotSupported)` will be raised if the component does not have access to the specified model. @@ -32,8 +33,16 @@ def infer_with_options(model: str, prompt: str, options: Optional[InferencingPar A `componentize_py_types.Err(spin_sdk.wit.imports.fermyon_spin_llm_2_0_0.Error_InvalidInput(str))` will be raised if an invalid input is provided. """ - options = options or InferencingParams - return spin_llm.infer(model, prompt, options) + some_options = options or InferencingParams() + my_options = spin_llm.InferencingParams( + some_options.max_tokens, + some_options.repeat_penalty, + some_options.repeat_penalty_last_n_token_count, + some_options.temperature, + some_options.top_k, + some_options.top_p, + ) + return spin_llm.infer(model, prompt, my_options) def infer(model: str, prompt: str) -> spin_llm.InferencingResult: """ @@ -43,6 +52,5 @@ def infer(model: str, prompt: str) -> spin_llm.InferencingResult: A `componentize_py_types.Err(spin_sdk.wit.imports.fermyon_spin_llm_2_0_0.Error_InvalidInput(str))` will be raised if an invalid input is provided. """ - options = InferencingParams - return spin_llm.infer(model, prompt, options) + return infer_with_options(model, prompt, None) diff --git a/src/spin_sdk/postgres.py b/src/spin_sdk/postgres.py index 785d32b..0ed6ddb 100644 --- a/src/spin_sdk/postgres.py +++ b/src/spin_sdk/postgres.py @@ -1,6 +1,10 @@ """Module for interacting with a Postgres database""" -from spin_sdk.wit.imports.spin_postgres_postgres_4_2_0 import Connection +from spin_sdk.wit.imports import spin_postgres_postgres_4_2_0 as pg + +Connection = pg.Connection +RowSet = pg.RowSet +DbValue = pg.DbValue async def open(connection_string: str) -> Connection: """ diff --git a/src/spin_sdk/sqlite.py b/src/spin_sdk/sqlite.py index c860777..6a39bb4 100644 --- a/src/spin_sdk/sqlite.py +++ b/src/spin_sdk/sqlite.py @@ -1,9 +1,13 @@ """Module for interacting with an SQLite database""" from typing import List -from spin_sdk.wit.imports.spin_sqlite_sqlite_3_1_0 import ( - Connection, Value_Integer, Value_Real, Value_Text, Value_Blob -) +from spin_sdk.wit.imports import spin_sqlite_sqlite_3_1_0 as sqlite + +Connection = sqlite.Connection +Value_Integer = sqlite.Value_Integer +Value_Real = sqlite.Value_Real +Value_Text = sqlite.Value_Text +Value_Blob = sqlite.Value_Blob async def open(name: str) -> Connection: """Open a connection to a named database instance. diff --git a/src/spin_sdk/variables.py b/src/spin_sdk/variables.py index dba0f25..3486933 100644 --- a/src/spin_sdk/variables.py +++ b/src/spin_sdk/variables.py @@ -2,7 +2,7 @@ from spin_sdk.wit.imports import spin_variables_variables_3_0_0 as variables -async def get(key: str): +async def get(key: str) -> str: """ Gets the value of the given key """