Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion api_tokens_sample.yml
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
api_key: # Insert api key here
api_key:
- # Insert api key here
- # Insert additional api key here (optional)
admin_key: # Insert admin key here
32 changes: 22 additions & 10 deletions common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import secrets
from ruamel.yaml import YAML
from fastapi import Header, HTTPException, Request
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from loguru import logger
from typing import Optional

Expand All @@ -18,22 +18,30 @@
class AuthKeys(BaseModel):
"""
This class represents the authentication keys for the application.
It contains two types of keys: 'api_key' and 'admin_key'.
The 'api_key' is used for general API calls, while the 'admin_key'
is used for administrative tasks. The class also provides a method
to verify if a given key matches the stored 'api_key' or 'admin_key'.
It contains two types of keys: 'api_key' (list) and 'admin_key'.
The 'api_key' can be a single string or a list of strings for general API calls,
while the 'admin_key' is used for administrative tasks. The class also provides
a method to verify if a given key matches any stored 'api_key' or 'admin_key'.
"""

api_key: str
api_key: list[str]
admin_key: str

@field_validator("api_key", mode="before")
@classmethod
def normalize_api_key(cls, v):
if isinstance(v, str):
return [v]
if isinstance(v, list):
return v
return []

def verify_key(self, test_key: str, key_type: str):
"""Verify if a given key matches the stored key."""
if key_type == "admin_key":
return test_key == self.admin_key
if key_type == "api_key":
# Admin keys are valid for all API calls
return test_key == self.api_key or test_key == self.admin_key
return test_key in self.api_key or test_key == self.admin_key
return False


Expand All @@ -60,15 +68,18 @@ async def load_auth_keys(disable_from_config: bool):

# Create a temporary YAML parser
yaml = YAML(typ=["rt", "safe"])
yaml.default_flow_style = False

try:
async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file:
contents = await auth_file.read()
auth_keys_dict = yaml.load(contents)
if "api_keys" in auth_keys_dict and "api_key" not in auth_keys_dict:
auth_keys_dict["api_key"] = auth_keys_dict.pop("api_keys")
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
except FileNotFoundError:
new_auth_keys = AuthKeys(
api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16)
api_key=[secrets.token_hex(16)], admin_key=secrets.token_hex(16)
)
AUTH_KEYS = new_auth_keys

Expand All @@ -78,8 +89,9 @@ async def load_auth_keys(disable_from_config: bool):

await auth_file.write(string_stream.getvalue())

api_keys_str = "\n".join([f" - {key}" for key in AUTH_KEYS.api_key])
logger.info(
f"Your API key is: {AUTH_KEYS.api_key}\n"
f"Your API keys are:\n{api_keys_str}\n"
f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
"If these keys get compromised, make sure to delete api_tokens.yml "
"and restart the server. Have fun!"
Expand Down