Skip to content

Commit c011835

Browse files
committed
use openai python module if available
1 parent 5cbf147 commit c011835

1 file changed

Lines changed: 57 additions & 27 deletions

File tree

examples/server/client.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
import urllib.request
99
import urllib.error
1010

11+
try:
12+
import openai
13+
except ImportError:
14+
openai = None
15+
1116
def parse_arguments():
1217
ap = argparse.ArgumentParser(
1318
description="Client for stable-diffusion.cpp sd-server",
@@ -113,7 +118,7 @@ def parse_arguments():
113118
return util_opts, gen_opts
114119

115120

116-
def build_openai_payload(gen_opts, util_opts):
121+
def build_openai_payload(gen_opts, util_opts, has_output_format=True):
117122
extension_data = {}
118123
api_data = {}
119124

@@ -143,7 +148,10 @@ def build_openai_payload(gen_opts, util_opts):
143148
extension_data["height"] = height
144149

145150
if gen_opts.get("output_format"):
146-
api_data["output_format"] = gen_opts["output_format"]
151+
if has_output_format:
152+
api_data["output_format"] = gen_opts["output_format"]
153+
else:
154+
api_data['extra_body'] = {"output_format": gen_opts["output_format"]}
147155

148156
if gen_opts.get("batch_count"):
149157
api_data["n"] = gen_opts["batch_count"]
@@ -232,36 +240,58 @@ def main():
232240

233241
if not server_url.endswith('/'):
234242
server_url += '/'
235-
endpoint = server_url + "v1/images/generations"
243+
base_url = server_url + "v1"
236244

237-
api_payload = build_openai_payload(gen_opts, util_opts)
245+
if openai:
238246

239-
if verbose:
240-
print(f"Sending request to: {endpoint}")
241-
print(f"Payload: {json.dumps(api_payload, indent=2)}")
247+
openai_version = openai.version.VERSION
248+
if verbose:
249+
print(f"Using OpenAI module {openai_version}")
242250

243-
req_data = json.dumps(api_payload).encode('utf-8')
244-
req = urllib.request.Request(endpoint, data=req_data, headers={'Content-Type': 'application/json'})
251+
# the output_format parameter was added to the openai module only in version
252+
# 1.76; for simplicity, always request it through the extra_body parameter
253+
has_output_format = False
254+
api_parameters = build_openai_payload(gen_opts, util_opts, has_output_format)
245255

246-
response_body = None
247-
try:
248-
with urllib.request.urlopen(req) as response:
249-
response_body = response.read().decode('utf-8')
250-
except urllib.error.HTTPError as e:
251-
print(f"HTTP Error {e.code}: {e.reason}")
252-
sys.exit(1)
253-
except urllib.error.URLError as e:
254-
print(f"URL Error: {e.reason}")
255-
sys.exit(1)
256-
except Exception as e:
257-
print(f"Request Error: {e}")
258-
sys.exit(1)
256+
if verbose:
257+
print(f"Base URL: {base_url}")
258+
print(f"Parameters: {json.dumps(api_parameters, indent=2)}")
259259

260-
try:
261-
images = decode_openai_response(response_body)
262-
except ValueError as e:
263-
print(f"Error decoding response: {e}")
264-
sys.exit(1)
260+
client = openai.OpenAI(api_key="local-api-key", base_url=base_url)
261+
result = client.images.generate(**api_parameters)
262+
images = [base64.b64decode(img.b64_json) for img in result.data]
263+
264+
else:
265+
266+
api_payload = build_openai_payload(gen_opts, util_opts)
267+
268+
endpoint = base_url + "/images/generations"
269+
if verbose:
270+
print(f"Sending request to: {endpoint}")
271+
print(f"Payload: {json.dumps(api_payload, indent=2)}")
272+
273+
req_data = json.dumps(api_payload).encode('utf-8')
274+
req = urllib.request.Request(endpoint, data=req_data, headers={'Content-Type': 'application/json'})
275+
276+
response_body = None
277+
try:
278+
with urllib.request.urlopen(req) as response:
279+
response_body = response.read().decode('utf-8')
280+
except urllib.error.HTTPError as e:
281+
print(f"HTTP Error {e.code}: {e.reason}")
282+
sys.exit(1)
283+
except urllib.error.URLError as e:
284+
print(f"URL Error: {e.reason}")
285+
sys.exit(1)
286+
except Exception as e:
287+
print(f"Request Error: {e}")
288+
sys.exit(1)
289+
290+
try:
291+
images = decode_openai_response(response_body)
292+
except ValueError as e:
293+
print(f"Error decoding response: {e}")
294+
sys.exit(1)
265295

266296
save_images(images, util_opts)
267297

0 commit comments

Comments
 (0)