diff --git a/api_tokens_sample.yml b/api_tokens_sample.yml index 9a0b75f9..a86279ff 100644 --- a/api_tokens_sample.yml +++ b/api_tokens_sample.yml @@ -1,2 +1,4 @@ -api_key: # Insert api key here +api_key: + - # Insert api key here + - # Insert additional api key here (optional) admin_key: # Insert admin key here \ No newline at end of file diff --git a/common/auth.py b/common/auth.py index b02cdd02..22eb0fec 100644 --- a/common/auth.py +++ b/common/auth.py @@ -8,7 +8,7 @@ import secrets from ruamel.yaml import YAML from fastapi import Header, HTTPException, Request -from pydantic import BaseModel +from pydantic import BaseModel, field_validator from loguru import logger from typing import Optional @@ -18,22 +18,30 @@ class AuthKeys(BaseModel): """ This class represents the authentication keys for the application. - It contains two types of keys: 'api_key' and 'admin_key'. - The 'api_key' is used for general API calls, while the 'admin_key' - is used for administrative tasks. The class also provides a method - to verify if a given key matches the stored 'api_key' or 'admin_key'. + It contains two types of keys: 'api_key' (list) and 'admin_key'. + The 'api_key' can be a single string or a list of strings for general API calls, + while the 'admin_key' is used for administrative tasks. The class also provides + a method to verify if a given key matches any stored 'api_key' or 'admin_key'. """ - api_key: str + api_key: list[str] admin_key: str + @field_validator("api_key", mode="before") + @classmethod + def normalize_api_key(cls, v): + if isinstance(v, str): + return [v] + if isinstance(v, list): + return v + return [] + def verify_key(self, test_key: str, key_type: str): """Verify if a given key matches the stored key.""" if key_type == "admin_key": return test_key == self.admin_key if key_type == "api_key": - # Admin keys are valid for all API calls - return test_key == self.api_key or test_key == self.admin_key + return test_key in self.api_key or test_key == self.admin_key return False @@ -60,15 +68,18 @@ async def load_auth_keys(disable_from_config: bool): # Create a temporary YAML parser yaml = YAML(typ=["rt", "safe"]) + yaml.default_flow_style = False try: async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file: contents = await auth_file.read() auth_keys_dict = yaml.load(contents) + if "api_keys" in auth_keys_dict and "api_key" not in auth_keys_dict: + auth_keys_dict["api_key"] = auth_keys_dict.pop("api_keys") AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict) except FileNotFoundError: new_auth_keys = AuthKeys( - api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16) + api_key=[secrets.token_hex(16)], admin_key=secrets.token_hex(16) ) AUTH_KEYS = new_auth_keys @@ -78,8 +89,9 @@ async def load_auth_keys(disable_from_config: bool): await auth_file.write(string_stream.getvalue()) + api_keys_str = "\n".join([f" - {key}" for key in AUTH_KEYS.api_key]) logger.info( - f"Your API key is: {AUTH_KEYS.api_key}\n" + f"Your API keys are:\n{api_keys_str}\n" f"Your admin key is: {AUTH_KEYS.admin_key}\n\n" "If these keys get compromised, make sure to delete api_tokens.yml " "and restart the server. Have fun!"