Skip to content

Commit 946aa8f

Browse files
authored
[headers] Allow user provided headers in completion (#116) (#71)
1 parent 62b51ca commit 946aa8f

File tree

4 files changed

+47
-9
lines changed

4 files changed

+47
-9
lines changed

openai/api_requestor.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def request(
100100
result = self.request_raw(
101101
method.lower(),
102102
url,
103-
params,
104-
headers,
103+
params=params,
104+
supplied_headers=headers,
105105
files=files,
106106
stream=stream,
107107
request_id=request_id,
@@ -212,18 +212,41 @@ def request_headers(
212212

213213
return headers
214214

215+
def _validate_headers(
216+
self, supplied_headers: Optional[Dict[str, str]]
217+
) -> Dict[str, str]:
218+
headers: Dict[str, str] = {}
219+
if supplied_headers is None:
220+
return headers
221+
222+
if not isinstance(supplied_headers, dict):
223+
raise TypeError("Headers must be a dictionary")
224+
225+
for k, v in supplied_headers.items():
226+
if not isinstance(k, str):
227+
raise TypeError("Header keys must be strings")
228+
if not isinstance(v, str):
229+
raise TypeError("Header values must be strings")
230+
headers[k] = v
231+
232+
# NOTE: It is possible to do more validation of the headers, but a request could always
233+
# be made to the API manually with invalid headers, so we need to handle them server side.
234+
235+
return headers
236+
215237
def request_raw(
216238
self,
217239
method,
218240
url,
241+
*,
219242
params=None,
220-
supplied_headers=None,
243+
supplied_headers: Dict[str, str] = None,
221244
files=None,
222-
stream=False,
245+
stream: bool = False,
223246
request_id: Optional[str] = None,
224247
) -> requests.Response:
225248
abs_url = "%s%s" % (self.api_base, url)
226-
headers = {}
249+
headers = self._validate_headers(supplied_headers)
227250

228251
data = None
229252
if method == "get" or method == "delete":
@@ -246,8 +269,6 @@ def request_raw(
246269
)
247270

248271
headers = self.request_headers(method, headers, request_id)
249-
if supplied_headers is not None:
250-
headers.update(supplied_headers)
251272

252273
util.log_info("Request to OpenAI API", method=method, path=abs_url)
253274
util.log_debug("Post details", data=data, api_version=self.api_version)

openai/api_resources/abstract/engine_api_resource.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def create(
6363
engine = params.pop("engine", None)
6464
timeout = params.pop("timeout", None)
6565
stream = params.get("stream", False)
66+
headers = params.pop("headers", None)
6667
if engine is None and cls.engine_required:
6768
raise error.InvalidRequestError(
6869
"Must provide an 'engine' parameter to create a %s" % cls, "engine"
@@ -87,7 +88,12 @@ def create(
8788
)
8889
url = cls.class_url(engine, api_type, api_version)
8990
response, _, api_key = requestor.request(
90-
"post", url, params, stream=stream, request_id=request_id
91+
"post",
92+
url,
93+
params=params,
94+
headers=headers,
95+
stream=stream,
96+
request_id=request_id,
9197
)
9298

9399
if stream:

openai/openai_object.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,12 @@ def request(
176176
organization=self.organization,
177177
)
178178
response, stream, api_key = requestor.request(
179-
method, url, params, stream=stream, headers=headers, request_id=request_id
179+
method,
180+
url,
181+
params=params,
182+
stream=stream,
183+
headers=headers,
184+
request_id=request_id,
180185
)
181186

182187
if stream:

openai/tests/test_endpoints.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,9 @@ def test_completions_multiple_prompts():
2828
prompt=["This was a test", "This was another test"], n=5, engine="ada"
2929
)
3030
assert len(result.choices) == 10
31+
32+
33+
def test_completions_model():
34+
result = openai.Completion.create(prompt="This was a test", n=5, model="ada")
35+
assert len(result.choices) == 5
36+
assert result.model.startswith("ada:")

0 commit comments

Comments
 (0)