|
8 | 8 | import urllib.request |
9 | 9 | import urllib.error |
10 | 10 |
|
| 11 | +try: |
| 12 | + import openai |
| 13 | +except ImportError: |
| 14 | + openai = None |
| 15 | + |
11 | 16 | def parse_arguments(): |
12 | 17 | ap = argparse.ArgumentParser( |
13 | 18 | description="Client for stable-diffusion.cpp sd-server", |
@@ -113,7 +118,7 @@ def parse_arguments(): |
113 | 118 | return util_opts, gen_opts |
114 | 119 |
|
115 | 120 |
|
116 | | -def build_openai_payload(gen_opts, util_opts): |
| 121 | +def build_openai_payload(gen_opts, util_opts, has_output_format=True): |
117 | 122 | extension_data = {} |
118 | 123 | api_data = {} |
119 | 124 |
|
@@ -143,7 +148,10 @@ def build_openai_payload(gen_opts, util_opts): |
143 | 148 | extension_data["height"] = height |
144 | 149 |
|
145 | 150 | if gen_opts.get("output_format"): |
146 | | - api_data["output_format"] = gen_opts["output_format"] |
| 151 | + if has_output_format: |
| 152 | + api_data["output_format"] = gen_opts["output_format"] |
| 153 | + else: |
| 154 | + api_data['extra_body'] = {"output_format": gen_opts["output_format"]} |
147 | 155 |
|
148 | 156 | if gen_opts.get("batch_count"): |
149 | 157 | api_data["n"] = gen_opts["batch_count"] |
@@ -232,36 +240,58 @@ def main(): |
232 | 240 |
|
233 | 241 | if not server_url.endswith('/'): |
234 | 242 | server_url += '/' |
235 | | - endpoint = server_url + "v1/images/generations" |
| 243 | + base_url = server_url + "v1" |
236 | 244 |
|
237 | | - api_payload = build_openai_payload(gen_opts, util_opts) |
| 245 | + if openai: |
238 | 246 |
|
239 | | - if verbose: |
240 | | - print(f"Sending request to: {endpoint}") |
241 | | - print(f"Payload: {json.dumps(api_payload, indent=2)}") |
| 247 | + openai_version = openai.version.VERSION |
| 248 | + if verbose: |
| 249 | + print(f"Using OpenAI module {openai_version}") |
242 | 250 |
|
243 | | - req_data = json.dumps(api_payload).encode('utf-8') |
244 | | - req = urllib.request.Request(endpoint, data=req_data, headers={'Content-Type': 'application/json'}) |
| 251 | + # the output_format parameter was added to the openai module only in version |
| 252 | + # 1.76; for simplicity, always request it through the extra_body parameter |
| 253 | + has_output_format = False |
| 254 | + api_parameters = build_openai_payload(gen_opts, util_opts, has_output_format) |
245 | 255 |
|
246 | | - response_body = None |
247 | | - try: |
248 | | - with urllib.request.urlopen(req) as response: |
249 | | - response_body = response.read().decode('utf-8') |
250 | | - except urllib.error.HTTPError as e: |
251 | | - print(f"HTTP Error {e.code}: {e.reason}") |
252 | | - sys.exit(1) |
253 | | - except urllib.error.URLError as e: |
254 | | - print(f"URL Error: {e.reason}") |
255 | | - sys.exit(1) |
256 | | - except Exception as e: |
257 | | - print(f"Request Error: {e}") |
258 | | - sys.exit(1) |
| 256 | + if verbose: |
| 257 | + print(f"Base URL: {base_url}") |
| 258 | + print(f"Parameters: {json.dumps(api_parameters, indent=2)}") |
259 | 259 |
|
260 | | - try: |
261 | | - images = decode_openai_response(response_body) |
262 | | - except ValueError as e: |
263 | | - print(f"Error decoding response: {e}") |
264 | | - sys.exit(1) |
| 260 | + client = openai.OpenAI(api_key="local-api-key", base_url=base_url) |
| 261 | + result = client.images.generate(**api_parameters) |
| 262 | + images = [base64.b64decode(img.b64_json) for img in result.data] |
| 263 | + |
| 264 | + else: |
| 265 | + |
| 266 | + api_payload = build_openai_payload(gen_opts, util_opts) |
| 267 | + |
| 268 | + endpoint = base_url + "/images/generations" |
| 269 | + if verbose: |
| 270 | + print(f"Sending request to: {endpoint}") |
| 271 | + print(f"Payload: {json.dumps(api_payload, indent=2)}") |
| 272 | + |
| 273 | + req_data = json.dumps(api_payload).encode('utf-8') |
| 274 | + req = urllib.request.Request(endpoint, data=req_data, headers={'Content-Type': 'application/json'}) |
| 275 | + |
| 276 | + response_body = None |
| 277 | + try: |
| 278 | + with urllib.request.urlopen(req) as response: |
| 279 | + response_body = response.read().decode('utf-8') |
| 280 | + except urllib.error.HTTPError as e: |
| 281 | + print(f"HTTP Error {e.code}: {e.reason}") |
| 282 | + sys.exit(1) |
| 283 | + except urllib.error.URLError as e: |
| 284 | + print(f"URL Error: {e.reason}") |
| 285 | + sys.exit(1) |
| 286 | + except Exception as e: |
| 287 | + print(f"Request Error: {e}") |
| 288 | + sys.exit(1) |
| 289 | + |
| 290 | + try: |
| 291 | + images = decode_openai_response(response_body) |
| 292 | + except ValueError as e: |
| 293 | + print(f"Error decoding response: {e}") |
| 294 | + sys.exit(1) |
265 | 295 |
|
266 | 296 | save_images(images, util_opts) |
267 | 297 |
|
|
0 commit comments