Skip to content

Commit 7f36b6a

Browse files
committed
Added workaround for API prompting and saving
1 parent 539b73a commit 7f36b6a

4 files changed

Lines changed: 63 additions & 61 deletions

File tree

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ attrs==24.2.0
77
certifi==2024.8.30
88
charset-normalizer==3.3.2
99
click==8.1.7
10-
colorama==0.4.6
1110
distro==1.9.0
1211
docstring_parser==0.16
1312
frozenlist==1.4.1

setup.py

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,6 @@
11
from setuptools import setup, find_packages
2-
from setuptools.command.install import install
32
import os
43

5-
DB_PATH = os.path.expanduser('~/.snapshell/system_info.db')
6-
7-
class CustomInstallCommand(install):
8-
def run(self):
9-
10-
# ANSI escape codes for colors
11-
GREEN = '\033[92m'
12-
YELLOW = '\033[93m'
13-
RESET = '\033[0m'
14-
15-
# Prompt user for GROQ API key
16-
print(GREEN + "Please enter your GROQ API key:" + RESET)
17-
groq_api_key = input(YELLOW + "> " + RESET)
18-
19-
# Detect the user's shell
20-
user_shell = os.environ.get('SHELL', '/bin/bash')
21-
shell_config_file = {
22-
'/bin/bash': '~/.bashrc',
23-
'/bin/zsh': '~/.zshrc',
24-
'/bin/fish': '~/.config/fish/config.fish'
25-
}.get(user_shell)
26-
27-
if shell_config_file:
28-
shell_config_path = os.path.expanduser(shell_config_file)
29-
30-
# Set the GROQ API key in the user's shell config file
31-
with open(shell_config_path, "a") as shell_config:
32-
shell_config.write(f'\nexport HELPER_GROQ_API_KEY="{groq_api_key}"\n')
33-
print(GREEN + f"Added GROQ API key to {shell_config_path}" + RESET)
34-
35-
# Reload shell configuration
36-
os.system(f'source {shell_config_path}')
37-
from snapshell.database import create_database, update_database
38-
print(YELLOW + f"Setting up the database... at {DB_PATH}" + RESET)
39-
create_database()
40-
install.run(self)
41-
update_database()
42-
43-
print(GREEN + "Database successfully created\n" + RESET)
44-
print(GREEN + "use the tool as snapshell\n" + RESET)
45-
# Restart the shell
46-
os.execvp(user_shell, [user_shell])
47-
48-
else:
49-
print(f"Unrecognized shell: {user_shell}. Please manually add the following lines to your shell configuration file:")
50-
print(f'export HELPER_GROQ_API_KEY="{groq_api_key}"')
51-
524
setup(
535
name='snapshell',
546
version='1.0.0',
@@ -81,7 +33,4 @@ def run(self):
8133
'Operating System :: OS Independent',
8234
],
8335
python_requires='>=3.10',
84-
cmdclass={
85-
'install': CustomInstallCommand,
86-
},
8736
)

snapshell/cli.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import argparse
2-
from .llm_api import suggest_command
3-
from .database import update_database, DB_PATH
4-
import sqlite3
52
import os
3+
import sqlite3
4+
from snapshell.database import create_database, update_database, DB_PATH
5+
from .llm_api import suggest_command, set_api_key, load_api_key
66

77
# ANSI escape codes for colors
88
RESET = "\033[0m"
@@ -13,6 +13,20 @@
1313
YELLOW = "\033[33m"
1414
RED = "\033[31m"
1515

16+
def initial_setup():
17+
# Prompt user for GROQ API key
18+
print(GREEN + "Please enter your GROQ API key:" + RESET)
19+
groq_api_key = input(YELLOW + "> " + RESET)
20+
21+
# Set the API key
22+
set_api_key(groq_api_key)
23+
24+
print(YELLOW + f"Setting up the database... at {DB_PATH}" + RESET)
25+
create_database()
26+
update_database()
27+
print(GREEN + "Database successfully created\n" + RESET)
28+
print(GREEN + "Use the tool as snapshell\n" + RESET)
29+
1630
def view_history():
1731
conn = sqlite3.connect(DB_PATH)
1832
cursor = conn.cursor()
@@ -46,8 +60,22 @@ def main():
4660
parser.add_argument('--update-db', action='store_true', help="Update the database with installed packages")
4761
parser.add_argument('--view-history', action='store_true', help="View command history")
4862
parser.add_argument('--clear-history', action='store_true', help="Clear command history")
63+
parser.add_argument('--set-api-key', type=str, help="Set the GROQ API key")
4964
args = parser.parse_args()
5065

66+
if args.set_api_key:
67+
set_api_key(args.set_api_key)
68+
print(GREEN + "API key set successfully." + RESET)
69+
return
70+
71+
# Load the API key from the configuration file
72+
API_KEY = load_api_key()
73+
74+
75+
if not API_KEY:
76+
print("This is working")
77+
initial_setup()
78+
5179
if args.update_db:
5280
print(f"{YELLOW}Updating database...{RESET}")
5381
update_database()
@@ -60,7 +88,7 @@ def main():
6088
if args.clear_history:
6189
clear_history()
6290
return
63-
91+
6492
print(f"{CYAN}Welcome to the Linux Command Tool. Type 'exit' to quit.{RESET}")
6593

6694
conversation_history = []

snapshell/llm_api.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
# llm_api.py
22
import os
3+
import json
34
from groq import Groq
45
from pydantic import BaseModel, Field, ValidationError
56
from .system_info import fetch_system_info
67
from .package_managers import detect_package_manager
78
import sqlite3
89
from .database import DB_PATH, save_command_suggestion
910

11+
# Configuration file path
12+
CONFIG_FILE = os.path.expanduser("~/.snapshell_config.json")
1013

11-
API_KEY = os.getenv('HELPER_GROQ_API_KEY')
12-
13-
groq = Groq(api_key=API_KEY)
14+
# Global variable to store the API key
15+
API_KEY = None
1416

1517
# Data model for LLM to generate
1618
class CommandSuggestion(BaseModel):
@@ -21,9 +23,26 @@ class CommandSuggestion(BaseModel):
2123
class SQLQuery(BaseModel):
2224
query: str = Field(description="The SQL query to be executed")
2325

26+
def load_api_key():
27+
global API_KEY
28+
if os.path.exists(CONFIG_FILE):
29+
with open(CONFIG_FILE, "r") as config_file:
30+
config = json.load(config_file)
31+
API_KEY = config.get("api_key")
32+
return API_KEY
33+
34+
def set_api_key(api_key):
35+
global API_KEY
36+
API_KEY = api_key
37+
config = {"api_key": api_key}
38+
with open(CONFIG_FILE, "w") as config_file:
39+
json.dump(config, config_file)
40+
2441
def suggest_command(user_input, conversation_history):
2542
if not API_KEY:
26-
raise ValueError("API key not found. Please set the GROQ_API_KEY environment variable.")
43+
raise ValueError("API key not set. Please set the API key using set_api_key function.")
44+
45+
groq = Groq(api_key=API_KEY)
2746

2847
system_info = fetch_system_info()
2948
package_manager = detect_package_manager()
@@ -103,6 +122,8 @@ def fallback_to_llm(user_input, package_manager, fallback_message, conversation_
103122

104123
messages.append({"role": "user", "content": user_input})
105124

125+
groq = Groq(api_key=API_KEY)
126+
106127
chat_completion = groq.chat.completions.create(
107128
messages=messages,
108129
model="llama-3.2-90b-text-preview",
@@ -131,6 +152,8 @@ def formulate_sql_query(user_input):
131152
{"role": "user", "content": user_input},
132153
]
133154

155+
groq = Groq(api_key=API_KEY)
156+
134157
chat_completion = groq.chat.completions.create(
135158
messages=messages,
136159
model="llama-3.2-90b-text-preview",
@@ -170,4 +193,7 @@ def query_database(sql_query):
170193
return relevant_packages
171194
except sqlite3.Error as e:
172195
# If there's any issue with the database query, return an empty result
173-
return []
196+
return []
197+
198+
# Load the API key from the configuration file on module import
199+
load_api_key()

0 commit comments

Comments
 (0)