Skip to content

Commit df8b0ba

Browse files
authored
Updates to the fine tuning SDK + addition of pagination primitives (#582)
* Add support for new fine_tuning SDK + pagination primitives * typo
1 parent b8dfa35 commit df8b0ba

File tree

7 files changed

+233
-0
lines changed

7 files changed

+233
-0
lines changed

openai/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ErrorObject,
2929
File,
3030
FineTune,
31+
FineTuningJob,
3132
Image,
3233
Model,
3334
Moderation,
@@ -84,6 +85,7 @@
8485
"ErrorObject",
8586
"File",
8687
"FineTune",
88+
"FineTuningJob",
8789
"InvalidRequestError",
8890
"Model",
8991
"Moderation",

openai/api_resources/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from openai.api_resources.error_object import ErrorObject # noqa: F401
1010
from openai.api_resources.file import File # noqa: F401
1111
from openai.api_resources.fine_tune import FineTune # noqa: F401
12+
from openai.api_resources.fine_tuning import FineTuningJob # noqa: F401
1213
from openai.api_resources.image import Image # noqa: F401
1314
from openai.api_resources.model import Model # noqa: F401
1415
from openai.api_resources.moderation import Moderation # noqa: F401

openai/api_resources/abstract/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,7 @@
77
from openai.api_resources.abstract.nested_resource_class_methods import (
88
nested_resource_class_methods,
99
)
10+
from openai.api_resources.abstract.paginatable_api_resource import (
11+
PaginatableAPIResource,
12+
)
1013
from openai.api_resources.abstract.updateable_api_resource import UpdateableAPIResource

openai/api_resources/abstract/nested_resource_class_methods.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,19 @@ def list_nested_resources(cls, id, **params):
124124
list_method = "list_%s" % resource_plural
125125
setattr(cls, list_method, classmethod(list_nested_resources))
126126

127+
elif operation == "paginated_list":
128+
129+
def paginated_list_nested_resources(
130+
cls, id, limit=None, cursor=None, **params
131+
):
132+
url = getattr(cls, resource_url_method)(id)
133+
return getattr(cls, resource_request_method)(
134+
"get", url, limit=limit, cursor=cursor, **params
135+
)
136+
137+
list_method = "list_%s" % resource_plural
138+
setattr(cls, list_method, classmethod(paginated_list_nested_resources))
139+
127140
else:
128141
raise ValueError("Unknown operation: %s" % operation)
129142

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from openai import api_requestor, error, util
2+
from openai.api_resources.abstract.listable_api_resource import ListableAPIResource
3+
from openai.util import ApiType
4+
5+
6+
class PaginatableAPIResource(ListableAPIResource):
7+
@classmethod
8+
def auto_paging_iter(cls, *args, **params):
9+
next_cursor = None
10+
has_more = True
11+
if not params.get("limit"):
12+
params["limit"] = 20
13+
while has_more:
14+
if next_cursor:
15+
params["after"] = next_cursor
16+
response = cls.list(*args, **params)
17+
18+
for item in response.data:
19+
yield item
20+
21+
if response.data:
22+
next_cursor = response.data[-1].id
23+
has_more = response.has_more
24+
25+
@classmethod
26+
def __prepare_list_requestor(
27+
cls,
28+
api_key=None,
29+
api_version=None,
30+
organization=None,
31+
api_base=None,
32+
api_type=None,
33+
):
34+
requestor = api_requestor.APIRequestor(
35+
api_key,
36+
api_base=api_base or cls.api_base(),
37+
api_version=api_version,
38+
api_type=api_type,
39+
organization=organization,
40+
)
41+
42+
typed_api_type, api_version = cls._get_api_type_and_version(
43+
api_type, api_version
44+
)
45+
46+
if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
47+
base = cls.class_url()
48+
url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version)
49+
elif typed_api_type == ApiType.OPEN_AI:
50+
url = cls.class_url()
51+
else:
52+
raise error.InvalidAPIType("Unsupported API type %s" % api_type)
53+
return requestor, url
54+
55+
@classmethod
56+
def list(
57+
cls,
58+
limit=None,
59+
starting_after=None,
60+
api_key=None,
61+
request_id=None,
62+
api_version=None,
63+
organization=None,
64+
api_base=None,
65+
api_type=None,
66+
**params,
67+
):
68+
requestor, url = cls.__prepare_list_requestor(
69+
api_key,
70+
api_version,
71+
organization,
72+
api_base,
73+
api_type,
74+
)
75+
76+
params = {
77+
**params,
78+
"limit": limit,
79+
"starting_after": starting_after,
80+
}
81+
82+
response, _, api_key = requestor.request(
83+
"get", url, params, request_id=request_id
84+
)
85+
openai_object = util.convert_to_openai_object(
86+
response, api_key, api_version, organization
87+
)
88+
openai_object._retrieve_params = params
89+
return openai_object
90+
91+
@classmethod
92+
async def alist(
93+
cls,
94+
limit=None,
95+
starting_after=None,
96+
api_key=None,
97+
request_id=None,
98+
api_version=None,
99+
organization=None,
100+
api_base=None,
101+
api_type=None,
102+
**params,
103+
):
104+
requestor, url = cls.__prepare_list_requestor(
105+
api_key,
106+
api_version,
107+
organization,
108+
api_base,
109+
api_type,
110+
)
111+
112+
params = {
113+
**params,
114+
"limit": limit,
115+
"starting_after": starting_after,
116+
}
117+
118+
response, _, api_key = await requestor.arequest(
119+
"get", url, params, request_id=request_id
120+
)
121+
openai_object = util.convert_to_openai_object(
122+
response, api_key, api_version, organization
123+
)
124+
openai_object._retrieve_params = params
125+
return openai_object

openai/api_resources/fine_tuning.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from urllib.parse import quote_plus
2+
3+
from openai import error
4+
from openai.api_resources.abstract import (
5+
CreateableAPIResource,
6+
PaginatableAPIResource,
7+
nested_resource_class_methods,
8+
)
9+
from openai.api_resources.abstract.deletable_api_resource import DeletableAPIResource
10+
from openai.util import ApiType
11+
12+
13+
@nested_resource_class_methods("event", operations=["paginated_list"])
14+
class FineTuningJob(
15+
PaginatableAPIResource, CreateableAPIResource, DeletableAPIResource
16+
):
17+
OBJECT_NAME = "fine_tuning.jobs"
18+
19+
@classmethod
20+
def _prepare_cancel(
21+
cls,
22+
id,
23+
api_key=None,
24+
api_type=None,
25+
request_id=None,
26+
api_version=None,
27+
**params,
28+
):
29+
base = cls.class_url()
30+
extn = quote_plus(id)
31+
32+
typed_api_type, api_version = cls._get_api_type_and_version(
33+
api_type, api_version
34+
)
35+
if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
36+
url = "/%s%s/%s/cancel?api-version=%s" % (
37+
cls.azure_api_prefix,
38+
base,
39+
extn,
40+
api_version,
41+
)
42+
elif typed_api_type == ApiType.OPEN_AI:
43+
url = "%s/%s/cancel" % (base, extn)
44+
else:
45+
raise error.InvalidAPIType("Unsupported API type %s" % api_type)
46+
47+
instance = cls(id, api_key, **params)
48+
return instance, url
49+
50+
@classmethod
51+
def cancel(
52+
cls,
53+
id,
54+
api_key=None,
55+
api_type=None,
56+
request_id=None,
57+
api_version=None,
58+
**params,
59+
):
60+
instance, url = cls._prepare_cancel(
61+
id,
62+
api_key,
63+
api_type,
64+
request_id,
65+
api_version,
66+
**params,
67+
)
68+
return instance.request("post", url, request_id=request_id)
69+
70+
@classmethod
71+
def acancel(
72+
cls,
73+
id,
74+
api_key=None,
75+
api_type=None,
76+
request_id=None,
77+
api_version=None,
78+
**params,
79+
):
80+
instance, url = cls._prepare_cancel(
81+
id,
82+
api_key,
83+
api_type,
84+
request_id,
85+
api_version,
86+
**params,
87+
)
88+
return instance.arequest("post", url, request_id=request_id)

openai/object_classes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
"fine-tune": api_resources.FineTune,
99
"model": api_resources.Model,
1010
"deployment": api_resources.Deployment,
11+
"fine_tuning.job": api_resources.FineTuningJob,
1112
}

0 commit comments

Comments
 (0)