44from abc import ABC , abstractmethod
55from typing import Literal
66
7- from codesectools .datasets import DATASETS_ALL
8- from codesectools .utils import (
9- USER_CONFIG_DIR ,
10- )
7+ import typer
8+ from git import Repo
9+ from rich import print
10+ from rich .panel import Panel
11+ from rich .progress import Progress
12+
13+ from codesectools .utils import USER_CACHE_DIR , USER_CONFIG_DIR
1114
1215
1316class SASTRequirement (ABC ):
@@ -24,7 +27,7 @@ def __init__(
2427
2528 Args:
2629 name: The name of the requirement.
27- instruction: A short instruction on how to fulfill the requirement.
30+ instruction: A short instruction on how to download the requirement.
2831 url: A URL for more detailed instructions.
2932 doc: A flag indicating if the instruction is available in the documentaton.
3033
@@ -44,6 +47,36 @@ def __repr__(self) -> str:
4447 return f"{ self .__class__ .__name__ } ({ self .name } )"
4548
4649
50+ class DownloadableRequirement (SASTRequirement ):
51+ """Represent a SAST requirement that can be downloaded automatically."""
52+
53+ def __init__ (
54+ self ,
55+ name : str ,
56+ instruction : str | None = None ,
57+ url : str | None = None ,
58+ doc : bool = False ,
59+ ) -> None :
60+ """Initialize a DownloadableRequirement instance.
61+
62+ Sets a standard instruction message on how to download the requirement using the CLI.
63+
64+ Args:
65+ name: The name of the requirement.
66+ instruction: A short instruction on how to download the requirement.
67+ url: A URL for more detailed instructions.
68+ doc: A flag indicating if the instruction is available in the documentaton.
69+
70+ """
71+ instruction = f"cstools download { name } "
72+ super ().__init__ (name , instruction , url , doc )
73+
74+ @abstractmethod
75+ def download (self , ** kwargs : dict ) -> None :
76+ """Download the requirement."""
77+ pass
78+
79+
4780class Config (SASTRequirement ):
4881 """Represent a configuration file requirement for a SAST tool."""
4982
@@ -58,7 +91,7 @@ def __init__(
5891
5992 Args:
6093 name: The name of the requirement.
61- instruction: A short instruction on how to fulfill the requirement.
94+ instruction: A short instruction on how to download the requirement.
6295 url: A URL for more detailed instructions.
6396 doc: A flag indicating if this is a documentation-only requirement.
6497
@@ -84,7 +117,7 @@ def __init__(
84117
85118 Args:
86119 name: The name of the requirement.
87- instruction: A short instruction on how to fulfill the requirement.
120+ instruction: A short instruction on how to download the requirement.
88121 url: A URL for more detailed instructions.
89122 doc: A flag indicating if this is a documentation-only requirement.
90123
@@ -96,31 +129,61 @@ def is_fulfilled(self, **kwargs: dict) -> bool:
96129 return bool (shutil .which (self .name ))
97130
98131
99- class DatasetCache ( SASTRequirement ):
100- """Represent a dataset cache requirement for a SAST tool ."""
132+ class GitRepo ( DownloadableRequirement ):
133+ """Represent a Git repository requirement that can be downloaded ."""
101134
102135 def __init__ (
103136 self ,
104137 name : str ,
138+ repo_url : str ,
105139 instruction : str | None = None ,
106140 url : str | None = None ,
107141 doc : bool = False ,
108142 ) -> None :
109- """Initialize a DatasetCache instance.
143+ """Initialize a GitRepo requirement instance.
110144
111145 Args:
112146 name: The name of the requirement.
113- instruction: A short instruction on how to fulfill the requirement.
147+ repo_url: The URL of the Git repository to clone.
148+ instruction: A short instruction on how to download the requirement.
114149 url: A URL for more detailed instructions.
115- doc: A flag indicating if this is a documentation-only requirement .
150+ doc: A flag indicating if the instruction is available in the documentaton .
116151
117152 """
118- instruction = f"cstools dataset download { name } "
119153 super ().__init__ (name , instruction , url , doc )
154+ self .repo_url = repo_url
155+ self .directory = USER_CACHE_DIR / self .name
120156
121157 def is_fulfilled (self , ** kwargs : dict ) -> bool :
122- """Check if the dataset is cached locally."""
123- return DATASETS_ALL [self .name ].is_cached ()
158+ """Check if the Git repository has been cloned."""
159+ return (self .directory / ".complete" ).is_file ()
160+
161+ def download (self , ** kwargs : dict ) -> None :
162+ """Prompt for license agreement and clone the Git repository."""
163+ panel = Panel (
164+ f"""Git repository:\t [b]{ self .name } [/b]
165+ Repository URL:\t [u]{ self .repo_url } [/u]
166+
167+ Please review the license of the repository at the URL above.
168+ By proceeding, you agree to abide by its terms.""" ,
169+ title = "[b]License Agreement[/b]" ,
170+ )
171+ print (panel )
172+
173+ agreed = typer .confirm ("Do you accept the license terms and wish to proceed?" )
174+ if not agreed :
175+ print ("[red]License agreement declined. Download aborted.[/red]" )
176+ raise typer .Exit (code = 1 )
177+
178+ with Progress () as progress :
179+ progress .add_task (f"Cloning repository [b]{ self .name } [/b]..." , total = None )
180+ Repo .clone_from (
181+ self .repo_url ,
182+ self .directory ,
183+ depth = 1 ,
184+ )
185+ (self .directory / ".complete" ).write_bytes (b"\x42 " )
186+ print (f"[b]{ self .name } [/b] has been downloaded at { self .directory } ." )
124187
125188
126189class SASTRequirements :
@@ -137,25 +200,26 @@ def __init__(
137200
138201 """
139202 self .name = None
140- self .full_reqs = full_reqs
141- self .partial_reqs = partial_reqs
203+ self .full = full_reqs
204+ self .partial = partial_reqs
205+ self .all = full_reqs + partial_reqs
142206
143207 def get_status (self ) -> Literal ["full" ] | Literal ["partial" ] | Literal ["none" ]:
144208 """Determine the operational status (full, partial, none) based on fulfilled requirements."""
145209 # full: can run sast analysis and result parsing
146210 # partial: can run result parsing
147211 # none: nothing
148212 status = "none"
149- if all (req .is_fulfilled (sast_name = self .name ) for req in self .partial_reqs ):
213+ if all (req .is_fulfilled (sast_name = self .name ) for req in self .partial ):
150214 status = "partial"
151- if all (req .is_fulfilled (sast_name = self .name ) for req in self .full_reqs ):
215+ if all (req .is_fulfilled (sast_name = self .name ) for req in self .full ):
152216 status = "full"
153217 return status
154218
155219 def get_missing (self ) -> list [SASTRequirement ]:
156220 """Get a list of all unfulfilled requirements."""
157221 missing = []
158- for req in self .full_reqs + self . partial_reqs :
222+ for req in self .all :
159223 if not req .is_fulfilled (sast_name = self .name ):
160224 missing .append (req )
161225 return missing
0 commit comments