diff --git a/.ci/ci.yml b/.ci/ci.yml index 7add12474..0f175cbde 100644 --- a/.ci/ci.yml +++ b/.ci/ci.yml @@ -69,7 +69,7 @@ stages: Write-Verbose -Verbose "Importing build utilities (buildtools.psd1)" Import-Module -Name $(Build.SourcesDirectory)/buildtools.psd1 -Force # - $(Build.SourcesDirectory)/build.ps1 -Build -Clean -BuildConfiguration Release -BuildFramework 'net472' + $(Build.SourcesDirectory)/build.ps1 -Build -Clean -BuildConfiguration Release -BuildFramework 'net8.0' displayName: Build Module - pwsh: | diff --git a/.ci/ci_auto.yml b/.ci/ci_auto.yml index 440d4753f..337563ce7 100644 --- a/.ci/ci_auto.yml +++ b/.ci/ci_auto.yml @@ -68,8 +68,8 @@ stages: Write-Verbose -Verbose "Importing build utilities (buildtools.psd1)" Import-Module -Name $(Build.SourcesDirectory)/buildtools.psd1 -Force # - # Build for net472 framework - $(Build.SourcesDirectory)/build.ps1 -Build -Clean -BuildConfiguration Release -BuildFramework 'net472' + # Build for net8.0 framework + $(Build.SourcesDirectory)/build.ps1 -Build -Clean -BuildConfiguration Release -BuildFramework 'net8.0' displayName: Build module - pwsh: | diff --git a/build.ps1 b/build.ps1 index 68de8f552..ae934a172 100644 --- a/build.ps1 +++ b/build.ps1 @@ -23,8 +23,8 @@ param ( [ValidateSet("Debug", "Release")] [string] $BuildConfiguration = "Debug", - [ValidateSet("net472")] - [string] $BuildFramework = "net472" + [ValidateSet("net8.0")] + [string] $BuildFramework = "net8.0" ) Import-Module -Name "$PSScriptRoot/buildtools.psd1" -Force diff --git a/doBuild.ps1 b/doBuild.ps1 index 15293d99b..fcb53b804 100644 --- a/doBuild.ps1 +++ b/doBuild.ps1 @@ -89,8 +89,12 @@ function DoBuild 'Azure.Core' 'Azure.Identity' 'Microsoft.Bcl.AsyncInterfaces' + 'Microsoft.Extensions.Caching.Abstractions' + 'Microsoft.Extensions.Caching.Memory' 'Microsoft.Extensions.FileProviders.Abstractions' 'Microsoft.Extensions.FileSystemGlobbing' + 'Microsoft.Extensions.Logging.Abstractions' + 'Microsoft.Extensions.Options' 'Microsoft.Extensions.Primitives' 'Microsoft.Identity.Client' 'Microsoft.Identity.Client.Extensions.Msal' @@ -107,20 +111,9 @@ function DoBuild 'NuGet.ProjectModel' 'NuGet.Protocol' 'NuGet.Versioning' - 'System.Buffers' - 'System.Diagnostics.DiagnosticSource' - 'System.IO.FileSystem.AccessControl' + 'OrasProject.Oras' 'System.Memory.Data' - 'System.Memory' - 'System.Numerics.Vectors' - 'System.Runtime.CompilerServices.Unsafe' - 'System.Security.AccessControl' 'System.Security.Cryptography.ProtectedData' - 'System.Security.Principal.Windows' - 'System.Text.Encodings.Web' - 'System.Text.Json' - 'System.Threading.Tasks.Extensions' - 'System.ValueTuple' ) $buildSuccess = $true diff --git a/src/code/ContainerRegistryServerAPICalls.cs b/src/code/ContainerRegistryServerAPICalls.cs index 9c17c0db0..93df80890 100644 --- a/src/code/ContainerRegistryServerAPICalls.cs +++ b/src/code/ContainerRegistryServerAPICalls.cs @@ -4,23 +4,28 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Collections.ObjectModel; using System.IO; using System.Linq; using System.Management.Automation; using System.Net; using System.Net.Http; -using System.Net.Http.Headers; using System.Security.Cryptography; using System.Text; using System.Text.Json; -using System.Text.RegularExpressions; +using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Memory; using Microsoft.PowerShell.PSResourceGet.Cmdlets; using Microsoft.PowerShell.PSResourceGet.UtilClasses; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; using NuGet.Versioning; +using OrasProject.Oras; +using OrasProject.Oras.Content; +using OrasProject.Oras.Oci; +using OrasProject.Oras.Registry; +using OrasProject.Oras.Registry.Remote; +using OrasProject.Oras.Registry.Remote.Auth; +using OrasRegistry = OrasProject.Oras.Registry.Remote.Registry; +using OrasRepository = OrasProject.Oras.Registry.Remote.Repository; namespace Microsoft.PowerShell.PSResourceGet { @@ -33,28 +38,15 @@ internal class ContainerRegistryServerAPICalls : ServerApiCall public override PSRepositoryInfo Repository { get; set; } public String Registry { get; set; } private readonly PSCmdlet _cmdletPassedIn; - private HttpClient _sessionClient { get; set; } private static readonly Hashtable[] emptyHashResponses = new Hashtable[] { }; private static FindResponseType containerRegistryFindResponseType = FindResponseType.ResponseString; private static readonly FindResults emptyResponseResults = new FindResults(stringResponse: Utils.EmptyStrArray, hashtableResponse: emptyHashResponses, responseType: containerRegistryFindResponseType); - const string containerRegistryRefreshTokenTemplate = "grant_type=access_token&service={0}&tenant={1}&access_token={2}"; // 0 - registry, 1 - tenant, 2 - access token - const string containerRegistryAccessTokenTemplate = "grant_type=refresh_token&service={0}&scope=repository:*:*&scope=registry:catalog:*&refresh_token={1}"; // 0 - registry, 1 - refresh token - const string containerRegistryOAuthExchangeUrlTemplate = "https://{0}/oauth2/exchange"; // 0 - registry - const string containerRegistryOAuthTokenUrlTemplate = "https://{0}/oauth2/token"; // 0 - registry - const string containerRegistryManifestUrlTemplate = "https://{0}/v2/{1}/manifests/{2}"; // 0 - registry, 1 - repo(modulename), 2 - tag(version) - const string containerRegistryBlobDownloadUrlTemplate = "https://{0}/v2/{1}/blobs/{2}"; // 0 - registry, 1 - repo(modulename), 2 - layer digest - const string containerRegistryFindImageVersionUrlTemplate = "https://{0}/v2/{1}/tags/list"; // 0 - registry, 1 - repo(modulename) - const string containerRegistryStartUploadTemplate = "https://{0}/v2/{1}/blobs/uploads/"; // 0 - registry, 1 - packagename - const string containerRegistryEndUploadTemplate = "https://{0}{1}&digest=sha256:{2}"; // 0 - registry, 1 - location, 2 - digest - const string defaultScope = "&scope=repository:*:*&scope=registry:catalog:*"; - const string catalogScope = "&scope=registry:catalog:*"; - const string grantTypeTemplate = "grant_type=access_token&service={0}{1}"; // 0 - registry, 1 - scope - const string authUrlTemplate = "{0}?service={1}{2}"; // 0 - realm, 1 - service, 2 - scope - - const string containerRegistryRepositoryListTemplate = "https://{0}/v2/_catalog"; // 0 - registry - - private string _cachedContainterRegistryToken = null; + // ORAS SDK objects + private readonly HttpClient _httpClient; + private readonly IClient _orasClient; + private readonly ICredentialProvider _credentialProvider; + private readonly IMemoryCache _memoryCache; #endregion @@ -65,16 +57,14 @@ public ContainerRegistryServerAPICalls(PSRepositoryInfo repository, PSCmdlet cmd Repository = repository; Registry = Repository.Uri.Host; _cmdletPassedIn = cmdletPassedIn; - HttpClientHandler handler = new HttpClientHandler() - { - Credentials = networkCredential - }; - _cachedContainterRegistryToken = null; + _httpClient = new HttpClient(); + _httpClient.Timeout = TimeSpan.FromMinutes(10); + _httpClient.DefaultRequestHeaders.TryAddWithoutValidation("User-Agent", userAgentString); - _sessionClient = new HttpClient(handler); - _sessionClient.Timeout = TimeSpan.FromMinutes(10); - _sessionClient.DefaultRequestHeaders.TryAddWithoutValidation("User-Agent", userAgentString); + _credentialProvider = new PSResourceGetCredentialProvider(repository, cmdletPassedIn); + _memoryCache = new MemoryCache(new MemoryCacheOptions()); + _orasClient = new Client(_httpClient, _credentialProvider, new Cache(_memoryCache)); } #endregion @@ -314,438 +304,208 @@ private Stream InstallVersion( _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::InstallVersion()"); errRecord = null; string packageNameLowercase = packageName.ToLower(); - string accessToken = string.Empty; - string tenantID = string.Empty; - string tempPath = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString()); + try { - Directory.CreateDirectory(tempPath); - } - catch (Exception e) - { - errRecord = new ErrorRecord( - exception: e, - "InstallVersionTempDirCreationError", - ErrorCategory.InvalidResult, - _cmdletPassedIn); + // Create ORAS repository for the specific package + var repo = CreateOrasRepository(packageNameLowercase); - return null; - } + _cmdletPassedIn.WriteVerbose($"Fetching manifest for {packageNameLowercase} - {packageVersion}"); - string containerRegistryAccessToken = GetContainerRegistryAccessToken(needCatalogAccess: false, isPushOperation: false, out errRecord); - if (errRecord != null) - { - return null; - } + // Fetch the manifest by version tag + var (manifestDescriptor, manifestStream) = repo.FetchAsync(packageVersion).GetAwaiter().GetResult(); + byte[] manifestBytes; + using (manifestStream) + { + manifestBytes = manifestStream.ReadAllAsync(manifestDescriptor).GetAwaiter().GetResult(); + } - _cmdletPassedIn.WriteVerbose($"Getting manifest for {packageNameLowercase} - {packageVersion}"); - var manifest = GetContainerRegistryRepositoryManifest(packageNameLowercase, packageVersion, containerRegistryAccessToken, out errRecord); - if (errRecord != null) - { - return null; - } - string digest = GetDigestFromManifest(manifest, out errRecord); - if (errRecord != null) - { - return null; - } + // Parse the manifest to get the layer descriptor (contains the nupkg blob) + var manifest = System.Text.Json.JsonSerializer.Deserialize(manifestBytes); + if (manifest == null || manifest.Layers == null || manifest.Layers.Count == 0) + { + errRecord = new ErrorRecord( + exception: new InvalidOperationException($"Manifest for {packageNameLowercase} version {packageVersion} has no layers."), + "ManifestNoLayersError", + ErrorCategory.InvalidResult, + _cmdletPassedIn); + return null; + } - _cmdletPassedIn.WriteVerbose($"Downloading blob for {packageNameLowercase} - {packageVersion}"); - HttpContent responseContent; - try - { - responseContent = GetContainerRegistryBlobAsync(packageNameLowercase, digest, containerRegistryAccessToken).Result; + // The first layer contains the nupkg + var nupkgLayer = manifest.Layers[0]; + _cmdletPassedIn.WriteVerbose($"Downloading blob for {packageNameLowercase} - {packageVersion} (digest: {nupkgLayer.Digest})"); + + // Fetch the blob content + using var blobStream = repo.FetchAsync(nupkgLayer).GetAwaiter().GetResult(); + var resultStream = new MemoryStream(); + blobStream.CopyTo(resultStream); + resultStream.Position = 0; + return resultStream; } catch (Exception e) { errRecord = new ErrorRecord( exception: e, - "InstallVersionGetContainerRegistryBlobAsyncError", + "InstallVersionOrasError", ErrorCategory.InvalidResult, _cmdletPassedIn); return null; } - - return responseContent.ReadAsStreamAsync().Result; } #endregion - #region Authentication and Token Methods + #region ORAS Helper Methods /// - /// Gets the access token for the container registry by following the below logic: - /// If a credential is provided when registering the repository, retrieve the token from SecretsManagement. - /// If no credential provided at registration then, check if the ACR endpoint can be accessed without a token. If not, try using Azure.Identity to get the az access token, then ACR refresh token and then ACR access token. - /// Note: Access token can be empty if the repository is unauthenticated + /// Creates an ORAS Repository instance for the given package name. /// - internal string GetContainerRegistryAccessToken(bool needCatalogAccess, bool isPushOperation, out ErrorRecord errRecord) + private OrasRepository CreateOrasRepository(string packageName) { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::GetContainerRegistryAccessToken()"); - string accessToken = string.Empty; - string containerRegistryAccessToken = string.Empty; - string tenantID = string.Empty; - errRecord = null; - - if (!string.IsNullOrEmpty(_cachedContainterRegistryToken)) + string reference = $"{Registry}/{packageName}"; + return new OrasRepository(new RepositoryOptions { - _cmdletPassedIn.WriteVerbose("Using cached container registry access token."); - return _cachedContainterRegistryToken; - } + Reference = Reference.Parse(reference), + Client = _orasClient, + }); + } - var repositoryCredentialInfo = Repository.CredentialInfo; - if (repositoryCredentialInfo != null) - { - accessToken = Utils.GetContainerRegistryAccessTokenFromSecretManagement( - Repository.Name, - repositoryCredentialInfo, - _cmdletPassedIn); + /// + /// Creates an ORAS Registry instance for catalog operations. + /// + private OrasRegistry CreateOrasRegistry() + { + return new OrasRegistry(Registry, _orasClient); + } - _cmdletPassedIn.WriteVerbose("Access token retrieved."); + /// + /// Lists all tags for a given package using ORAS. + /// + internal List ListImageTags(string packageName, out ErrorRecord errRecord) + { + _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::ListImageTags()"); + errRecord = null; + var tags = new List(); - tenantID = repositoryCredentialInfo.SecretName; - } - else + try { - // A container registry repository is determined to be unauthenticated if it allows anonymous pull access. However, push operations always require authentication. - bool isRepositoryUnauthenticated = isPushOperation ? false : IsContainerRegistryUnauthenticated(Repository.Uri.ToString(), needCatalogAccess, out errRecord, out accessToken); - _cmdletPassedIn.WriteInformation($"Value of isRepositoryUnauthenticated: {isRepositoryUnauthenticated}", new string[] { "PSRGContainerRegistryUnauthenticatedCheck" }); - - _cmdletPassedIn.WriteDebug($"Is repository unauthenticated: {isRepositoryUnauthenticated}"); - - if (errRecord != null) - { - return null; - } - - if (!string.IsNullOrEmpty(accessToken)) - { - _cmdletPassedIn.WriteVerbose("Anonymous access token retrieved."); - return accessToken; - } - - if (!isRepositoryUnauthenticated) + var repo = CreateOrasRepository(packageName); + var tagsEnumerable = repo.ListTagsAsync(""); + // Collect all tags synchronously + var enumerator = tagsEnumerable.GetAsyncEnumerator(); + try { - accessToken = Utils.GetAzAccessToken(_cmdletPassedIn); - if (string.IsNullOrEmpty(accessToken)) + while (enumerator.MoveNextAsync().AsTask().GetAwaiter().GetResult()) { - errRecord = new ErrorRecord( - new InvalidOperationException("Failed to get access token from Azure."), - "AzAccessTokenFailure", - ErrorCategory.AuthenticationError, - this); - - return null; + tags.Add(enumerator.Current); } } - else + finally { - _cmdletPassedIn.WriteVerbose("Repository is unauthenticated"); - return null; + enumerator.DisposeAsync().AsTask().GetAwaiter().GetResult(); } } - - var containerRegistryRefreshToken = GetContainerRegistryRefreshToken(tenantID, accessToken, out errRecord); - if (errRecord != null) - { - return null; - } - - containerRegistryAccessToken = GetContainerRegistryAccessTokenByRefreshToken(containerRegistryRefreshToken, out errRecord); - if (errRecord != null) + catch (Exception e) { - return null; + errRecord = new ErrorRecord( + exception: e, + "ListImageTagsOrasError", + ErrorCategory.InvalidResult, + _cmdletPassedIn); } - _cmdletPassedIn.WriteVerbose("Container registry access token retrieved."); - _cachedContainterRegistryToken = containerRegistryAccessToken; - - return containerRegistryAccessToken; + return tags; } /// - /// Checks if container registry repository is unauthenticated. + /// Lists all repositories in the registry using ORAS. /// - internal bool IsContainerRegistryUnauthenticated(string containerRegistryUrl, bool needCatalogAccess, out ErrorRecord errRecord, out string anonymousAccessToken) + internal List ListAllRepositories(out ErrorRecord errRecord) { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::IsContainerRegistryUnauthenticated()"); + _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::ListAllRepositories()"); errRecord = null; - anonymousAccessToken = string.Empty; - string endpoint = $"{containerRegistryUrl}/v2/"; - HttpResponseMessage response; + var repositories = new List(); + try { - response = _sessionClient.SendAsync(new HttpRequestMessage(HttpMethod.Head, endpoint)).Result; - - if (response.StatusCode == HttpStatusCode.Unauthorized) + var registry = CreateOrasRegistry(); + var repoEnumerable = registry.ListRepositoriesAsync(""); + var enumerator = repoEnumerable.GetAsyncEnumerator(); + try { - // check if there is a auth challenge header - if (response.Headers.WwwAuthenticate.Count() > 0) + while (enumerator.MoveNextAsync().AsTask().GetAwaiter().GetResult()) { - var authHeader = response.Headers.WwwAuthenticate.First(); - if (authHeader.Scheme == "Bearer") - { - // check if there is a realm - if (authHeader.Parameter.Contains("realm")) - { - // get the realm - var realm = authHeader.Parameter.Split(',')?.Where(x => x.Contains("realm"))?.FirstOrDefault()?.Split('=')[1]?.Trim('"'); - // get the service - var service = authHeader.Parameter.Split(',')?.Where(x => x.Contains("service"))?.FirstOrDefault()?.Split('=')[1]?.Trim('"'); - - if (string.IsNullOrEmpty(realm) || string.IsNullOrEmpty(service)) - { - errRecord = new ErrorRecord( - new InvalidOperationException("Failed to get realm or service from the auth challenge header."), - "RegistryUnauthenticationCheckError", - ErrorCategory.InvalidResult, - this); - - return false; - } - - string content = needCatalogAccess ? String.Format(grantTypeTemplate, service, catalogScope) : String.Format(grantTypeTemplate, service, defaultScope); - - var contentHeaders = new Collection> { new KeyValuePair("Content-Type", "application/x-www-form-urlencoded") }; - - string url = needCatalogAccess ? String.Format(authUrlTemplate, realm, service, catalogScope) : String.Format(authUrlTemplate, realm, service, defaultScope); - - _cmdletPassedIn.WriteDebug($"Getting anonymous access token from the realm: {url}"); - - // we don't check the error record here because we want to return false if we get a 401 and not throw an error - _cmdletPassedIn.WriteDebug($"Getting anonymous access token from the realm: {url}"); - ErrorRecord errRecordTemp = null; - - var results = GetHttpResponseJObjectUsingContentHeaders(url, HttpMethod.Get, content, contentHeaders, out errRecordTemp); - - if (results == null) - { - _cmdletPassedIn.WriteDebug("Failed to get access token from the realm. results is null."); - _cmdletPassedIn.WriteDebug($"ErrorRecord: {errRecordTemp}"); - return false; - } - - if (results["access_token"] == null) - { - _cmdletPassedIn.WriteDebug($"Failed to get access token from the realm. access_token is null. results: {results}"); - return false; - } - - anonymousAccessToken = results["access_token"].ToString(); - return true; - } - } + repositories.Add(enumerator.Current); } } - } - catch (HttpRequestException hre) - { - errRecord = new ErrorRecord( - hre, - "RegistryAnonymousAcquireError", - ErrorCategory.ConnectionError, - this); - - return false; + finally + { + enumerator.DisposeAsync().AsTask().GetAwaiter().GetResult(); + } } catch (Exception e) { errRecord = new ErrorRecord( - e, - "RegistryUnauthenticationCheckError", + exception: e, + "ListAllRepositoriesOrasError", ErrorCategory.InvalidResult, - this); - - return false; - } - - return (response.StatusCode == HttpStatusCode.OK); - } - - /// - /// Given the access token retrieved from credentials, gets the refresh token. - /// - internal string GetContainerRegistryRefreshToken(string tenant, string accessToken, out ErrorRecord errRecord) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::GetContainerRegistryRefreshToken()"); - string content = string.Format(containerRegistryRefreshTokenTemplate, Registry, tenant, accessToken); - var contentHeaders = new Collection> { new KeyValuePair("Content-Type", "application/x-www-form-urlencoded") }; - string exchangeUrl = string.Format(containerRegistryOAuthExchangeUrlTemplate, Registry); - var results = GetHttpResponseJObjectUsingContentHeaders(exchangeUrl, HttpMethod.Post, content, contentHeaders, out errRecord); - if (errRecord != null || results == null || results["refresh_token"] == null) - { - return string.Empty; - } - - return results["refresh_token"].ToString(); - } - - /// - /// Given the refresh token, gets the new access token with appropriate scope access permissions. - /// - internal string GetContainerRegistryAccessTokenByRefreshToken(string refreshToken, out ErrorRecord errRecord) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::GetContainerRegistryAccessTokenByRefreshToken()"); - string content = string.Format(containerRegistryAccessTokenTemplate, Registry, refreshToken); - var contentHeaders = new Collection> { new KeyValuePair("Content-Type", "application/x-www-form-urlencoded") }; - string tokenUrl = string.Format(containerRegistryOAuthTokenUrlTemplate, Registry); - var results = GetHttpResponseJObjectUsingContentHeaders(tokenUrl, HttpMethod.Post, content, contentHeaders, out errRecord); - if (errRecord != null || results == null || results["access_token"] == null) - { - return string.Empty; + _cmdletPassedIn); } - return results["access_token"].ToString(); + return repositories; } - #endregion - - #region Private Methods - /// - /// Parses package manifest JObject to find digest entry, which is the SHA needed to identify and get the package. + /// Fetches the manifest for a specific package version and parses it. + /// Returns the manifest as a parsed OCI Manifest object. /// - private string GetDigestFromManifest(JObject manifest, out ErrorRecord errRecord) + internal Manifest FetchManifest(string packageName, string version, out ErrorRecord errRecord) { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::GetDigestFromManifest()"); + _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::FetchManifest()"); errRecord = null; - string digest = String.Empty; - if (manifest == null) + try { - errRecord = new ErrorRecord( - exception: new ArgumentNullException("Manifest (passed in to determine digest) is null."), - "ManifestNullError", - ErrorCategory.InvalidArgument, - _cmdletPassedIn); + var repo = CreateOrasRepository(packageName); + var (descriptor, stream) = repo.FetchAsync(version).GetAwaiter().GetResult(); + byte[] manifestBytes; + using (stream) + { + manifestBytes = stream.ReadAllAsync(descriptor).GetAwaiter().GetResult(); + } - return digest; + var manifest = System.Text.Json.JsonSerializer.Deserialize(manifestBytes); + return manifest; } - - JToken layers = manifest["layers"]; - if (layers == null || !layers.HasValues) + catch (Exception e) { errRecord = new ErrorRecord( - exception: new ArgumentNullException("Manifest 'layers' property (passed in to determine digest) is null or does not have values."), - "ManifestLayersNullOrEmptyError", - ErrorCategory.InvalidArgument, + exception: e, + "FetchManifestOrasError", + ErrorCategory.InvalidResult, _cmdletPassedIn); - return digest; - } - - foreach (JObject item in layers) - { - if (item.ContainsKey("digest")) - { - digest = item.GetValue("digest").ToString(); - break; - } + return null; } - - return digest; - } - - /// - /// Gets the manifest for a package (ie repository in container registry terms) from the repository (ie registry in container registry terms) - /// - internal JObject GetContainerRegistryRepositoryManifest(string packageName, string version, string containerRegistryAccessToken, out ErrorRecord errRecord) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::GetContainerRegistryRepositoryManifest()"); - // example of manifestUrl: https://psgetregistry.azurecr.io/hello-world:3.0.0 - string manifestUrl = string.Format(containerRegistryManifestUrlTemplate, Registry, packageName, version); - var defaultHeaders = GetDefaultHeaders(containerRegistryAccessToken); - return GetHttpResponseJObjectUsingDefaultHeaders(manifestUrl, HttpMethod.Get, defaultHeaders, out errRecord); - } - - /// - /// Get the blob for the package (ie repository in container registry terms) from the repository (ie registry in container registry terms) - /// Used when installing the package - /// - internal async Task GetContainerRegistryBlobAsync(string packageName, string digest, string containerRegistryAccessToken) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::GetContainerRegistryBlobAsync()"); - string blobUrl = string.Format(containerRegistryBlobDownloadUrlTemplate, Registry, packageName, digest); - var defaultHeaders = GetDefaultHeaders(containerRegistryAccessToken); - return await GetHttpContentResponseJObject(blobUrl, defaultHeaders); - } - - /// - /// Gets the image tags associated with the package (i.e repository in container registry terms), where the tag corresponds to the package's versions. - /// If the package version is specified search for that specific tag for the image, if the package version is "*" search for all tags for the image. - /// - internal JObject FindContainerRegistryImageTags(string packageName, string version, string containerRegistryAccessToken, out ErrorRecord errRecord) - { - /* - { - "name": "", - "tags": [ - "", - "", - "" - ] - } - */ - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::FindContainerRegistryImageTags()"); - string resolvedVersion = string.Equals(version, "*", StringComparison.OrdinalIgnoreCase) ? null : $"/{version}"; - string findImageUrl = string.Format(containerRegistryFindImageVersionUrlTemplate, Registry, packageName); - var defaultHeaders = GetDefaultHeaders(containerRegistryAccessToken); - return GetHttpResponseJObjectUsingDefaultHeaders(findImageUrl, HttpMethod.Get, defaultHeaders, out errRecord); - } - - /// - /// Helper method to find all packages on container registry - /// - /// - /// - /// - internal JObject FindAllRepositories(string containerRegistryAccessToken, out ErrorRecord errRecord) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::FindAllRepositories()"); - string repositoryListUrl = string.Format(containerRegistryRepositoryListTemplate, Registry); - var defaultHeaders = GetDefaultHeaders(containerRegistryAccessToken); - return GetHttpResponseJObjectUsingDefaultHeaders(repositoryListUrl, HttpMethod.Get, defaultHeaders, out errRecord, usePagination: true); } /// - /// Get metadata for a package version. + /// Get metadata for a package version by fetching its manifest and reading annotations. /// - internal Hashtable GetContainerRegistryMetadata(string packageName, string exactTagVersion, string containerRegistryAccessToken, out ErrorRecord errRecord) + internal Hashtable GetContainerRegistryMetadata(string packageName, string exactTagVersion, out ErrorRecord errRecord) { _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::GetContainerRegistryMetadata()"); Hashtable requiredVersionResponse = new(); - JObject foundTags = FindContainerRegistryManifest(packageName, exactTagVersion, containerRegistryAccessToken, out errRecord); + var manifest = FetchManifest(packageName, exactTagVersion, out errRecord); if (errRecord != null) { return requiredVersionResponse; } - /* - Response returned looks something like: - { - "schemaVersion": 2, - "config": { - "mediaType": "application/vnd.unknown.config.v1+json", - "digest": "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - "size": 0 - }, - "layers": [ - { - "mediaType": "application/vnd.oci.image.layer.nondistributable.v1.tar+gzip'", - "digest": "sha256:7c55c7b66cb075628660d8249cc4866f16e34741c246a42ed97fb23ccd4ea956", - "size": 3533, - "annotations": { - "org.opencontainers.image.title": "test_module.1.0.0.nupkg", - "metadata": "{\"GUID\":\"45219bf4-10a4-4242-92d6-9bfcf79878fd\",\"FunctionsToExport\":[],\"CompanyName\":\"Anam\",\"CmdletsToExport\":[],\"VariablesToExport\":\"*\",\"Author\":\"Anam Navied\",\"ModuleVersion\":\"1.0.0\",\"Copyright\":\"(c) Anam Navied. All rights reserved.\",\"PrivateData\":{\"PSData\":{\"Tags\":[\"Test\",\"CommandsAndResource\",\"Tag2\"]}},\"RequiredModules\":[],\"Description\":\"This is a test module, for PSGallery team internal testing. Do not take a dependency on this package. This version contains tags for the package.\",\"AliasesToExport\":[]}" - } - } - ] - } - */ - - ContainerRegistryInfo serverPkgInfo = GetMetadataProperty(foundTags, packageName, out errRecord); + ContainerRegistryInfo serverPkgInfo = GetMetadataProperty(manifest, packageName, out errRecord); if (errRecord != null) { return requiredVersionResponse; @@ -828,29 +588,15 @@ internal Hashtable GetContainerRegistryMetadata(string packageName, string exact } /// - /// Get the manifest associated with the package version. - /// - internal JObject FindContainerRegistryManifest(string packageName, string version, string containerRegistryAccessToken, out ErrorRecord errRecord) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::FindContainerRegistryManifest()"); - var createManifestUrl = string.Format(containerRegistryManifestUrlTemplate, Registry, packageName, version); - _cmdletPassedIn.WriteDebug($"GET manifest url: {createManifestUrl}"); - - var defaultHeaders = GetDefaultHeaders(containerRegistryAccessToken); - return GetHttpResponseJObjectUsingDefaultHeaders(createManifestUrl, HttpMethod.Get, defaultHeaders, out errRecord); - } - - /// - /// Get metadata for the package by parsing its manifest. + /// Get metadata for the package by parsing its OCI manifest annotations. /// - internal ContainerRegistryInfo GetMetadataProperty(JObject foundTags, string packageName, out ErrorRecord errRecord) + internal ContainerRegistryInfo GetMetadataProperty(Manifest manifest, string packageName, out ErrorRecord errRecord) { _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::GetMetadataProperty()"); errRecord = null; ContainerRegistryInfo serverPkgInfo = null; - JToken layers = foundTags["layers"]; - if (layers == null || layers[0] == null) + if (manifest == null || manifest.Layers == null || manifest.Layers.Count == 0 || manifest.Layers[0] == null) { errRecord = new ErrorRecord( new InvalidOrEmptyResponse($"Response does not contain 'layers' element in manifest for package '{packageName}' in '{Repository.Name}'."), @@ -861,7 +607,7 @@ internal ContainerRegistryInfo GetMetadataProperty(JObject foundTags, string pac return serverPkgInfo; } - JToken annotations = layers[0]["annotations"]; + var annotations = manifest.Layers[0].Annotations; if (annotations == null) { errRecord = new ErrorRecord( @@ -874,11 +620,10 @@ internal ContainerRegistryInfo GetMetadataProperty(JObject foundTags, string pac } // Check for package name - JToken pkgTitleJToken = annotations["org.opencontainers.image.title"]; - if (pkgTitleJToken == null) + if (!annotations.TryGetValue("org.opencontainers.image.title", out string metadataPkgName) || string.IsNullOrWhiteSpace(metadataPkgName)) { errRecord = new ErrorRecord( - new InvalidOrEmptyResponse($"Response does not contain 'org.opencontainers.image.title' element for package '{packageName}' in '{Repository.Name}'."), + new InvalidOrEmptyResponse($"Response does not contain or has empty 'org.opencontainers.image.title' element for package '{packageName}' in '{Repository.Name}'."), "GetMetadataPropertyOCITitleError", ErrorCategory.InvalidData, this); @@ -886,21 +631,8 @@ internal ContainerRegistryInfo GetMetadataProperty(JObject foundTags, string pac return serverPkgInfo; } - string metadataPkgName = pkgTitleJToken.ToString(); - if (string.IsNullOrWhiteSpace(metadataPkgName)) - { - errRecord = new ErrorRecord( - new InvalidOrEmptyResponse($"Response element 'org.opencontainers.image.title' is empty for package '{packageName}' in '{Repository.Name}'."), - "GetMetadataPropertyOCITitleEmptyError", - ErrorCategory.InvalidData, - this); - - return serverPkgInfo; - } - // Check for package metadata - JToken pkgMetadataJToken = annotations["metadata"]; - if (pkgMetadataJToken == null) + if (!annotations.TryGetValue("metadata", out string metadata) || metadata == null) { errRecord = new ErrorRecord( new InvalidOrEmptyResponse($"Response does not contain 'metadata' element in manifest for package '{packageName}' in '{Repository.Name}'."), @@ -911,801 +643,110 @@ internal ContainerRegistryInfo GetMetadataProperty(JObject foundTags, string pac return serverPkgInfo; } - string metadata = pkgMetadataJToken.ToString(); - // Check for package artifact type - JToken resourceTypeJToken = annotations["resourceType"]; - var resourceType = resourceTypeJToken != null ? resourceTypeJToken.ToString() : "None"; + annotations.TryGetValue("resourceType", out string resourceType); + resourceType = resourceType ?? "None"; return new ContainerRegistryInfo(metadataPkgName, metadata, resourceType); } - /// - /// Upload manifest for the package, used for publishing. - /// - internal async Task UploadManifest(string packageName, string packageVersion, string configPath, bool isManifest, string containerRegistryAccessToken) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::UploadManifest()"); - try - { - var createManifestUrl = string.Format(containerRegistryManifestUrlTemplate, Registry, packageName, packageVersion); - var defaultHeaders = GetDefaultHeaders(containerRegistryAccessToken); - return await PutRequestAsync(createManifestUrl, configPath, isManifest, defaultHeaders); - } - catch (HttpRequestException e) - { - throw new HttpRequestException("Error occurred while trying to create manifest: " + e.Message); - } - } + #endregion - internal async Task GetHttpContentResponseJObject(string url, Collection> defaultHeaders) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::GetHttpContentResponseJObject()"); - try - { - HttpRequestMessage request = new HttpRequestMessage(HttpMethod.Get, url); - SetDefaultHeaders(defaultHeaders); - return await SendContentRequestAsync(request); - } - catch (HttpRequestException e) - { - throw new HttpRequestException("Error occurred while trying to retrieve response: " + e.Message); - } - } + #region Publish Methods /// - /// Get response object when using default headers in the request. + /// Helper method that publishes a package to the container registry. + /// This gets called from Publish-PSResource. /// - internal JObject GetHttpResponseJObjectUsingDefaultHeaders(string url, HttpMethod method, Collection> defaultHeaders, out ErrorRecord errRecord, bool usePagination = false) + internal bool PushNupkgContainerRegistry( + string outputNupkgDir, + string packageName, + string modulePrefix, + NuGetVersion packageVersion, + ResourceType resourceType, + Hashtable parsedMetadataHash, + Hashtable dependencies, + bool isNupkgPathSpecified, + string originalNupkgPath, + out ErrorRecord errRecord) { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::GetHttpResponseJObjectUsingDefaultHeaders()"); + _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::PushNupkgContainerRegistry()"); + errRecord = null; + + // if isNupkgPathSpecified, then we need to publish the original .nupkg file, as it may be signed + string fullNupkgFile = isNupkgPathSpecified ? originalNupkgPath : System.IO.Path.Combine(outputNupkgDir, packageName + "." + packageVersion.ToNormalizedString() + ".nupkg"); + + string pkgNameForUpload = string.IsNullOrEmpty(modulePrefix) ? packageName : modulePrefix + "/" + packageName; + string packageNameLowercase = pkgNameForUpload.ToLower(); + try { - errRecord = null; - HttpRequestMessage request = new HttpRequestMessage(method, url); - SetDefaultHeaders(defaultHeaders); + var repo = CreateOrasRepository(packageNameLowercase); - var response = usePagination ? SendRequestAsyncWithPagination(request) : SendRequestAsync(request); - return response.GetAwaiter().GetResult(); - } - catch (ResourceNotFoundException e) - { - errRecord = new ErrorRecord( - exception: e, - "ResourceNotFound", - ErrorCategory.InvalidResult, - _cmdletPassedIn); - } - catch (UnauthorizedException e) - { - errRecord = new ErrorRecord( - exception: e, - "UnauthorizedRequest", - ErrorCategory.InvalidResult, - _cmdletPassedIn); - } - catch (HttpRequestException e) - { - errRecord = new ErrorRecord( - exception: e, - "HttpRequestCallFailure", - ErrorCategory.InvalidResult, - _cmdletPassedIn); + // Read the nupkg file bytes + _cmdletPassedIn.WriteVerbose($"Reading .nupkg file: {fullNupkgFile}"); + byte[] nupkgBytes = File.ReadAllBytes(fullNupkgFile); + + // Create metadata JSON string + _cmdletPassedIn.WriteVerbose("Create package version metadata as JSON string"); + string metadataJson = CreateMetadataContent(resourceType, parsedMetadataHash, out errRecord); + if (errRecord != null) + { + return false; + } + + var fileName = System.IO.Path.GetFileName(fullNupkgFile); + + // Create layer descriptor for the nupkg with annotations + var nupkgDescriptor = Descriptor.Create(nupkgBytes, OrasProject.Oras.Oci.MediaType.ImageLayerGzip); + nupkgDescriptor.Annotations = new Dictionary + { + ["org.opencontainers.image.title"] = packageName, + ["org.opencontainers.image.description"] = fileName, + ["metadata"] = metadataJson, + ["resourceType"] = resourceType.ToString() + }; + + // Push the nupkg layer + _cmdletPassedIn.WriteVerbose($"Pushing .nupkg blob for {packageNameLowercase}"); + repo.PushAsync(nupkgDescriptor, new MemoryStream(nupkgBytes)).GetAwaiter().GetResult(); + + // Create config descriptor + byte[] configBytes = Array.Empty(); + var configDescriptor = Descriptor.Create(configBytes, OrasProject.Oras.Oci.MediaType.ImageConfig); + + // Pack and push the manifest using Packer + _cmdletPassedIn.WriteVerbose("Packing and pushing manifest"); + var packOptions = new PackManifestOptions + { + Config = configDescriptor, + Layers = new List { nupkgDescriptor } + }; + + var manifestDescriptor = Packer.PackManifestAsync(repo, Packer.ManifestVersion.Version1_1, "", packOptions).GetAwaiter().GetResult(); + + // Tag the manifest with the version + _cmdletPassedIn.WriteVerbose($"Tagging manifest with version: {packageVersion.OriginalVersion}"); + repo.TagAsync(manifestDescriptor, packageVersion.OriginalVersion).GetAwaiter().GetResult(); } catch (Exception e) { errRecord = new ErrorRecord( - exception: e, - "HttpRequestCallFailure", + new UploadBlobException($"Error occurred while publishing package to ContainerRegistry: {e.GetType()} '{e.Message}'", e), + "PackagePublishOrasError", ErrorCategory.InvalidResult, _cmdletPassedIn); + + return false; } - return null; + return true; } /// - /// Get response object when using content headers in the request. + /// Create metadata for the package that will be populated in the manifest. /// - internal JObject GetHttpResponseJObjectUsingContentHeaders(string url, HttpMethod method, string content, Collection> contentHeaders, out ErrorRecord errRecord) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::GetHttpResponseJObjectUsingContentHeaders()"); - errRecord = null; - try - { - HttpRequestMessage request = new HttpRequestMessage(method, url); - - // HTTP GET does not expect a body / content. - if (method != HttpMethod.Get) - { - - if (string.IsNullOrEmpty(content)) - { - errRecord = new ErrorRecord( - exception: new ArgumentNullException($"Content is null or empty and cannot be used to make a request as its content headers."), - "RequestContentHeadersNullOrEmpty", - ErrorCategory.InvalidData, - _cmdletPassedIn); - - return null; - } - - // codeql[cs/sensitive-data-transmission] This is expected PSResourceGet behavior to create the content of the request which is only transmitted to the server, not the user. This information is also not exposed back to the user via error or verbose messaging. - request.Content = new StringContent(content); - request.Content.Headers.Clear(); - if (contentHeaders != null) - { - foreach (var header in contentHeaders) - { - request.Content.Headers.Add(header.Key, header.Value); - } - } - } - - return SendRequestAsync(request).GetAwaiter().GetResult(); - } - catch (ResourceNotFoundException e) - { - errRecord = new ErrorRecord( - exception: e, - "ResourceNotFound", - ErrorCategory.InvalidResult, - _cmdletPassedIn); - } - catch (UnauthorizedException e) - { - errRecord = new ErrorRecord( - exception: e, - "UnauthorizedRequest", - ErrorCategory.InvalidResult, - _cmdletPassedIn); - } - catch (HttpRequestException e) - { - errRecord = new ErrorRecord( - exception: e, - "HttpRequestCallFailure", - ErrorCategory.InvalidResult, - _cmdletPassedIn); - } - catch (Exception e) - { - errRecord = new ErrorRecord( - exception: e, - "HttpRequestCallFailure", - ErrorCategory.InvalidResult, - _cmdletPassedIn); - } - - return null; - } - - /// - /// Get response headers. - /// - internal async Task GetHttpResponseHeader(string url, HttpMethod method, Collection> defaultHeaders) - { - try - { - HttpRequestMessage request = new HttpRequestMessage(method, url); - SetDefaultHeaders(defaultHeaders); - return await SendRequestHeaderAsync(request); - } - catch (HttpRequestException e) - { - throw new HttpRequestException("Error occurred while trying to retrieve response header: " + e.Message); - } - } - - /// - /// Set default headers for HttpClient. - /// - private void SetDefaultHeaders(Collection> defaultHeaders) - { - _sessionClient.DefaultRequestHeaders.Clear(); - if (defaultHeaders != null) - { - foreach (var header in defaultHeaders) - { - if (string.Equals(header.Key, "Authorization", StringComparison.OrdinalIgnoreCase)) - { - _sessionClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", header.Value); - } - else if (string.Equals(header.Key, "Accept", StringComparison.OrdinalIgnoreCase)) - { - _sessionClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue(header.Value)); - } - else - { - _sessionClient.DefaultRequestHeaders.Add(header.Key, header.Value); - } - } - } - } - - /// - /// Sends request for content. - /// - private async Task SendContentRequestAsync(HttpRequestMessage message) - { - try - { - HttpResponseMessage response = await _sessionClient.SendAsync(message); - response.EnsureSuccessStatusCode(); - return response.Content; - } - catch (Exception e) - { - throw new SendRequestException($"Error occurred while sending request to Container Registry server for content with: {e.GetType()} '{e.Message}'", e); - } - } - - /// - /// Sends HTTP request. - /// - private async Task SendRequestAsync(HttpRequestMessage message) - { - HttpResponseMessage response; - try - { - response = await _sessionClient.SendAsync(message); - } - catch (Exception e) - { - throw new SendRequestException($"Error occurred while sending request to Container Registry server with: {e.GetType()} '{e.Message}'", e); - } - - switch (response.StatusCode) - { - case HttpStatusCode.OK: - break; - - case HttpStatusCode.Unauthorized: - throw new UnauthorizedException($"Response returned status code: {response.ReasonPhrase}."); - - case HttpStatusCode.NotFound: - throw new ResourceNotFoundException($"Response returned status code package: {response.ReasonPhrase}."); - - default: - throw new Exception($"Response returned error with status code {response.StatusCode}: {response.ReasonPhrase}."); - } - - return JsonConvert.DeserializeObject(await response.Content.ReadAsStringAsync()); - } - - private async Task SendRequestAsyncWithPagination(HttpRequestMessage initialMessage) - { - HttpResponseMessage response; - string nextUrl = initialMessage.RequestUri.ToString(); - string urlBase = initialMessage.RequestUri.Scheme + "://" + initialMessage.RequestUri.Host; - JObject finalResult = new JObject(); - JArray allRepositories = new JArray(); - - do - { - var message = new HttpRequestMessage(HttpMethod.Get, nextUrl); - try - { - response = await _sessionClient.SendAsync(message); - } - catch (Exception e) - { - throw new SendRequestException($"Error occurred while sending request to Container Registry server with: {e.GetType()} '{e.Message}'", e); - } - - switch (response.StatusCode) - { - case HttpStatusCode.OK: - break; - case HttpStatusCode.Unauthorized: - throw new UnauthorizedException($"Response returned status code: {response.ReasonPhrase}."); - case HttpStatusCode.NotFound: - throw new ResourceNotFoundException($"Response returned status code package: {response.ReasonPhrase}."); - default: - throw new Exception($"Response returned error with status code {response.StatusCode}: {response.ReasonPhrase}."); - } - - var content = await response.Content.ReadAsStringAsync(); - var json = JObject.Parse(content); - var repositories = json["repositories"] as JArray; - if (repositories != null) - { - allRepositories.Merge(repositories); - } - - // Check for Link header to continue pagination - if (response.Headers.TryGetValues("Link", out var linkHeaders)) - { - var linkHeader = string.Join(",", linkHeaders); - var match = Regex.Match(linkHeader, @"<([^>]+)>;\s*rel=""next"""); - var nextUrlPart = match.Success ? match.Groups[1].Value : null; - if (!string.IsNullOrEmpty(nextUrlPart)) - { - nextUrl = urlBase + nextUrlPart; - } - else - { - nextUrl = null; - } - } - else - { - nextUrl = null; - } - - } while (!string.IsNullOrEmpty(nextUrl)); - - finalResult["repositories"] = allRepositories; - return finalResult; - } - - /// - /// Send request to get response headers. - /// - private async Task SendRequestHeaderAsync(HttpRequestMessage message) - { - try - { - HttpResponseMessage response = await _sessionClient.SendAsync(message); - response.EnsureSuccessStatusCode(); - return response.Headers; - } - catch (HttpRequestException e) - { - throw new HttpRequestException("Error occurred while trying to retrieve response: " + e.Message); - } - } - - /// - /// Sends a PUT request, used for publishing to container registry. - /// - private async Task PutRequestAsync(string url, string filePath, bool isManifest, Collection> contentHeaders) - { - try - { - SetDefaultHeaders(contentHeaders); - - FileInfo fileInfo = new FileInfo(filePath); - using (FileStream fileStream = fileInfo.Open(FileMode.Open, FileAccess.Read)) - { - HttpContent httpContent = new StreamContent(fileStream); - if (isManifest) - { - httpContent.Headers.Add("Content-Type", "application/vnd.oci.image.manifest.v1+json"); - } - else - { - httpContent.Headers.Add("Content-Type", "application/octet-stream"); - } - - return await _sessionClient.PutAsync(url, httpContent); - } - } - catch (Exception e) - { - throw new SendRequestException($"Error occurred while uploading module to ContainerRegistry: {e.GetType()} '{e.Message}'", e); - } - } - - /// - /// Get the default headers associated with the access token. - /// - private static Collection> GetDefaultHeaders(string containerRegistryAccessToken) - { - var defaultHeaders = new Collection>(); - - if (!string.IsNullOrEmpty(containerRegistryAccessToken)) - { - defaultHeaders.Add(new KeyValuePair("Authorization", containerRegistryAccessToken)); - } - - defaultHeaders.Add(new KeyValuePair("Accept", "application/vnd.oci.image.manifest.v1+json")); - - return defaultHeaders; - } - - #endregion - - #region Publish Methods - /// - /// Helper method that publishes a package to the container registry. - /// This gets called from Publish-PSResource. - /// - internal bool PushNupkgContainerRegistry( - string outputNupkgDir, - string packageName, - string modulePrefix, - NuGetVersion packageVersion, - ResourceType resourceType, - Hashtable parsedMetadataHash, - Hashtable dependencies, - bool isNupkgPathSpecified, - string originalNupkgPath, - out ErrorRecord errRecord) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::PushNupkgContainerRegistry()"); - - // if isNupkgPathSpecified, then we need to publish the original .nupkg file, as it may be signed - string fullNupkgFile = isNupkgPathSpecified ? originalNupkgPath : System.IO.Path.Combine(outputNupkgDir, packageName + "." + packageVersion.ToNormalizedString() + ".nupkg"); - - string pkgNameForUpload = string.IsNullOrEmpty(modulePrefix) ? packageName : modulePrefix + "/" + packageName; - string packageNameLowercase = pkgNameForUpload.ToLower(); - - // Get access token (includes refresh tokens) - _cmdletPassedIn.WriteVerbose($"Get access token for container registry server."); - var containerRegistryAccessToken = GetContainerRegistryAccessToken(needCatalogAccess: false, isPushOperation: true, out errRecord); - if (errRecord != null) - { - return false; - } - - // Upload .nupkg - _cmdletPassedIn.WriteVerbose($"Upload .nupkg file: {fullNupkgFile}"); - string nupkgDigest = UploadNupkgFile(packageNameLowercase, containerRegistryAccessToken, fullNupkgFile, out errRecord); - if (errRecord != null) - { - return false; - } - - // Create and upload an empty file-- needed by ContainerRegistry server - CreateAndUploadEmptyFile(outputNupkgDir, packageNameLowercase, containerRegistryAccessToken, out errRecord); - if (errRecord != null) - { - return false; - } - - // Create config.json file - var configFilePath = System.IO.Path.Combine(outputNupkgDir, "config.json"); - _cmdletPassedIn.WriteVerbose($"Create config.json file at path: {configFilePath}"); - string configDigest = CreateConfigFile(configFilePath, out errRecord); - if (errRecord != null) - { - return false; - } - - _cmdletPassedIn.WriteVerbose("Create package version metadata as JSON string"); - // Create module metadata string - string metadataJson = CreateMetadataContent(resourceType, parsedMetadataHash, out errRecord); - if (errRecord != null) - { - return false; - } - - // Create and upload manifest - TryCreateAndUploadManifest(fullNupkgFile, nupkgDigest, configDigest, packageName, modulePrefix, resourceType, metadataJson, configFilePath, packageVersion, containerRegistryAccessToken, out errRecord); - if (errRecord != null) - { - return false; - } - - return true; - } - - /// - /// Upload the nupkg file, by creating a digest for it and uploading as blob. - /// Note: ContainerRegistry registries will only accept a name that is all lowercase. - /// - private string UploadNupkgFile(string packageNameLowercase, string containerRegistryAccessToken, string fullNupkgFile, out ErrorRecord errRecord) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::UploadNupkgFile()"); - _cmdletPassedIn.WriteVerbose("Start uploading blob"); - string nupkgDigest = string.Empty; - errRecord = null; - string moduleLocation; - try - { - moduleLocation = GetStartUploadBlobLocation(packageNameLowercase, containerRegistryAccessToken).Result; - } - catch (Exception startUploadException) - { - errRecord = new ErrorRecord( - startUploadException, - "StartUploadBlobLocationError", - ErrorCategory.InvalidResult, - _cmdletPassedIn); - - return nupkgDigest; - } - - _cmdletPassedIn.WriteVerbose("Computing digest for .nupkg file"); - nupkgDigest = CreateDigest(fullNupkgFile, out errRecord); - if (errRecord != null) - { - return nupkgDigest; - } - - _cmdletPassedIn.WriteVerbose("Finish uploading blob"); - try - { - var responseNupkg = EndUploadBlob(moduleLocation, fullNupkgFile, nupkgDigest, isManifest: false, containerRegistryAccessToken).Result; - bool uploadSuccessful = responseNupkg.IsSuccessStatusCode; - - if (!uploadSuccessful) - { - errRecord = new ErrorRecord( - new UploadBlobException("Uploading of blob for publish failed."), - "EndUploadBlobError", - ErrorCategory.InvalidResult, - _cmdletPassedIn); - - return nupkgDigest; - } - } - catch (Exception endUploadException) - { - errRecord = new ErrorRecord( - endUploadException, - "EndUploadBlobError", - ErrorCategory.InvalidResult, - _cmdletPassedIn); - - return nupkgDigest; - } - - return nupkgDigest; - } - - /// - /// Uploads an empty file at the start of publish as is needed. - /// - private void CreateAndUploadEmptyFile(string outputNupkgDir, string pkgNameLower, string containerRegistryAccessToken, out ErrorRecord errRecord) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::CreateAndUploadEmptyFile()"); - _cmdletPassedIn.WriteVerbose("Create an empty file"); - string emptyFileName = "empty" + Guid.NewGuid().ToString() + ".txt"; - var emptyFilePath = System.IO.Path.Combine(outputNupkgDir, emptyFileName); - - try - { - Utils.CreateFile(emptyFilePath); - - _cmdletPassedIn.WriteVerbose("Start uploading an empty file"); - string emptyLocation = GetStartUploadBlobLocation(pkgNameLower, containerRegistryAccessToken).Result; - - _cmdletPassedIn.WriteVerbose("Computing digest for empty file"); - string emptyFileDigest = CreateDigest(emptyFilePath, out errRecord); - if (errRecord != null) - { - return; - } - - _cmdletPassedIn.WriteVerbose("Finish uploading empty file"); - var emptyResponse = EndUploadBlob(emptyLocation, emptyFilePath, emptyFileDigest, false, containerRegistryAccessToken).Result; - bool uploadSuccessful = emptyResponse.IsSuccessStatusCode; - - if (!uploadSuccessful) - { - errRecord = new ErrorRecord( - new UploadBlobException($"Error occurred while uploading blob, response code was: {emptyResponse.StatusCode} with reason {emptyResponse.ReasonPhrase}"), - "UploadEmptyFileError", - ErrorCategory.InvalidResult, - _cmdletPassedIn); - - return; - } - } - catch (Exception e) - { - errRecord = new ErrorRecord( - e, - "UploadEmptyFileError", - ErrorCategory.InvalidResult, - _cmdletPassedIn); - - return; - } - } - - /// - /// Create config file associated with the package (i.e repository in container registry terms) as is needed for the package's manifest config layer - /// - private string CreateConfigFile(string configFilePath, out ErrorRecord errRecord) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::CreateConfigFile()"); - string configFileDigest = string.Empty; - _cmdletPassedIn.WriteVerbose("Create the config file"); - while (File.Exists(configFilePath)) - { - configFilePath = Guid.NewGuid().ToString() + ".json"; - } - - try - { - Utils.CreateFile(configFilePath); - - _cmdletPassedIn.WriteVerbose("Computing digest for config"); - configFileDigest = CreateDigest(configFilePath, out errRecord); - if (errRecord != null) - { - return configFileDigest; - } - } - catch (Exception e) - { - errRecord = new ErrorRecord( - e, - "CreateConfigFileError", - ErrorCategory.InvalidResult, - _cmdletPassedIn); - - return configFileDigest; - } - - return configFileDigest; - } - - /// - /// Create the manifest for the package and upload it - /// - private bool TryCreateAndUploadManifest(string fullNupkgFile, - string nupkgDigest, - string configDigest, - string packageName, - string modulePrefix, - ResourceType resourceType, - string metadataJson, - string configFilePath, - NuGetVersion pkgVersion, - string containerRegistryAccessToken, - out ErrorRecord errRecord) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::TryCreateAndUploadManifest()"); - errRecord = null; - - string pkgNameForUpload = string.IsNullOrEmpty(modulePrefix) ? packageName : modulePrefix + "/" + packageName; - string packageNameLowercase = pkgNameForUpload.ToLower(); - - FileInfo nupkgFile = new FileInfo(fullNupkgFile); - var fileSize = nupkgFile.Length; - var fileName = System.IO.Path.GetFileName(fullNupkgFile); - string fileContent = CreateManifestContent(nupkgDigest, configDigest, fileSize, fileName, packageName, resourceType, metadataJson); - File.WriteAllText(configFilePath, fileContent); - - _cmdletPassedIn.WriteVerbose("Create the manifest layer"); - bool manifestCreated = false; - try - { - HttpResponseMessage manifestResponse = UploadManifest(packageNameLowercase, pkgVersion.OriginalVersion, configFilePath, true, containerRegistryAccessToken).Result; - manifestCreated = manifestResponse.IsSuccessStatusCode; - } - catch (Exception e) - { - errRecord = new ErrorRecord( - new UploadBlobException($"Error occurred while uploading package manifest to ContainerRegistry: {e.GetType()} '{e.Message}'", e), - "PackageManifestUploadError", - ErrorCategory.InvalidResult, - _cmdletPassedIn); - - return manifestCreated; - } - - return manifestCreated; - } - - /// - /// Create the content for the manifest for the packge. - /// - private string CreateManifestContent( - string nupkgDigest, - string configDigest, - long nupkgFileSize, - string fileName, - string packageName, - ResourceType resourceType, - string metadata) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::CreateManifestContent()"); - StringBuilder stringBuilder = new StringBuilder(); - StringWriter stringWriter = new StringWriter(stringBuilder); - JsonTextWriter jsonWriter = new JsonTextWriter(stringWriter); - - jsonWriter.Formatting = Newtonsoft.Json.Formatting.Indented; - - // start of manifest JSON object - jsonWriter.WriteStartObject(); - - jsonWriter.WritePropertyName("schemaVersion"); - jsonWriter.WriteValue(2); - jsonWriter.WritePropertyName("mediaType"); - jsonWriter.WriteValue("application/vnd.oci.image.manifest.v1+json"); - - jsonWriter.WritePropertyName("config"); - jsonWriter.WriteStartObject(); - jsonWriter.WritePropertyName("mediaType"); - jsonWriter.WriteValue("application/vnd.oci.image.config.v1+json"); - jsonWriter.WritePropertyName("digest"); - jsonWriter.WriteValue($"sha256:{configDigest}"); - jsonWriter.WritePropertyName("size"); - jsonWriter.WriteValue(0); - jsonWriter.WriteEndObject(); - - jsonWriter.WritePropertyName("layers"); - jsonWriter.WriteStartArray(); - - jsonWriter.WriteStartObject(); - jsonWriter.WritePropertyName("mediaType"); - jsonWriter.WriteValue("application/vnd.oci.image.layer.v1.tar+gzip"); - jsonWriter.WritePropertyName("digest"); - jsonWriter.WriteValue($"sha256:{nupkgDigest}"); - jsonWriter.WritePropertyName("size"); - jsonWriter.WriteValue(nupkgFileSize); - jsonWriter.WritePropertyName("annotations"); - jsonWriter.WriteStartObject(); - jsonWriter.WritePropertyName("org.opencontainers.image.title"); - jsonWriter.WriteValue(packageName); - jsonWriter.WritePropertyName("org.opencontainers.image.description"); - jsonWriter.WriteValue(fileName); - jsonWriter.WritePropertyName("metadata"); - jsonWriter.WriteValue(metadata); - jsonWriter.WritePropertyName("resourceType"); - jsonWriter.WriteValue(resourceType.ToString()); - jsonWriter.WriteEndObject(); // end of annotations object - - jsonWriter.WriteEndObject(); // end of 'layers' entry object - - jsonWriter.WriteEndArray(); // end of 'layers' array - jsonWriter.WriteEndObject(); // end of manifest JSON object - - return stringWriter.ToString(); - } - - /// - /// Create SHA256 digest that will be associated with .nupkg, config file or empty file. - /// - private string CreateDigest(string fileName, out ErrorRecord errRecord) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::CreateDigest()"); - errRecord = null; - string digest = string.Empty; - FileInfo fileInfo = new FileInfo(fileName); - SHA256 mySHA256 = SHA256.Create(); - - using (FileStream fileStream = fileInfo.Open(FileMode.Open, FileAccess.Read)) - { - try - { - // Create a fileStream for the file. - // Be sure it's positioned to the beginning of the stream. - fileStream.Position = 0; - // Compute the hash of the fileStream. - byte[] hashValue = mySHA256.ComputeHash(fileStream); - StringBuilder stringBuilder = new StringBuilder(); - foreach (byte b in hashValue) - { - stringBuilder.AppendFormat("{0:x2}", b); - } - - digest = stringBuilder.ToString(); - } - catch (IOException ex) - { - errRecord = new ErrorRecord(ex, $"IOException for .nupkg file: {ex.Message}", ErrorCategory.InvalidOperation, null); - return digest; - } - catch (UnauthorizedAccessException ex) - { - errRecord = new ErrorRecord(ex, $"UnauthorizedAccessException for .nupkg file: {ex.Message}", ErrorCategory.PermissionDenied, null); - return digest; - } - catch (Exception ex) - { - errRecord = new ErrorRecord(ex, $"Exception when creating digest: {ex.Message}", ErrorCategory.PermissionDenied, null); - return digest; - } - } - - if (String.IsNullOrEmpty(digest)) - { - errRecord = new ErrorRecord(new ArgumentNullException("Digest created was null or empty."), "DigestNullOrEmptyError.", ErrorCategory.InvalidResult, null); - } - - return digest; - } - - /// - /// Create metadata for the package that will be populated in the manifest. - /// - private string CreateMetadataContent(ResourceType resourceType, Hashtable parsedMetadata, out ErrorRecord errRecord) + private string CreateMetadataContent(ResourceType resourceType, Hashtable parsedMetadata, out ErrorRecord errRecord) { _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::CreateMetadataContent()"); errRecord = null; @@ -1746,42 +787,6 @@ private string CreateMetadataContent(ResourceType resourceType, Hashtable parsed return jsonString; } - /// - /// Get start location when uploading blob, used during publish. - /// - internal async Task GetStartUploadBlobLocation(string packageName, string containerRegistryAccessToken) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::GetStartUploadBlobLocation()"); - try - { - var defaultHeaders = GetDefaultHeaders(containerRegistryAccessToken); - var startUploadUrl = string.Format(containerRegistryStartUploadTemplate, Registry, packageName); - return (await GetHttpResponseHeader(startUploadUrl, HttpMethod.Post, defaultHeaders)).Location.ToString(); - } - catch (Exception e) - { - throw new UploadBlobException($"Error occurred while starting to upload the blob location used for publishing to ContainerRegistry: {e.GetType()} '{e.Message}'", e); - } - } - - /// - /// Upload blob, used for publishing - /// - internal async Task EndUploadBlob(string location, string filePath, string digest, bool isManifest, string containerRegistryAccessToken) - { - _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::EndUploadBlob()"); - try - { - var endUploadUrl = string.Format(containerRegistryEndUploadTemplate, Registry, location, digest); - var defaultHeaders = GetDefaultHeaders(containerRegistryAccessToken); - return await PutRequestAsync(endUploadUrl, filePath, isManifest, defaultHeaders); - } - catch (Exception e) - { - throw new UploadBlobException($"Error occurred while uploading module to ContainerRegistry: {e.GetType()} '{e.Message}'", e); - } - } - #endregion #region Find Helper Methods @@ -1792,26 +797,17 @@ internal async Task EndUploadBlob(string location, string f private Hashtable[] FindPackagesWithVersionHelper(string packageName, VersionType versionType, VersionRange versionRange, NuGetVersion requiredVersion, bool includePrerelease, bool getOnlyLatest, out ErrorRecord errRecord) { _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::FindPackagesWithVersionHelper()"); - string accessToken = string.Empty; - string tenantID = string.Empty; - string registryUrl = Repository.Uri.ToString(); string packageNameLowercase = packageName.ToLower(); string packageNameForFind = PrependMARPrefix(packageNameLowercase); - string containerRegistryAccessToken = GetContainerRegistryAccessToken(needCatalogAccess: false, isPushOperation: false,out errRecord); - if (errRecord != null) - { - return emptyHashResponses; - } - var foundTags = FindContainerRegistryImageTags(packageNameForFind, "*", containerRegistryAccessToken, out errRecord); - if (errRecord != null || foundTags == null) + var allVersionsList = ListImageTags(packageNameForFind, out errRecord); + if (errRecord != null || allVersionsList == null) { return emptyHashResponses; } List latestVersionResponse = new List(); - List allVersionsList = foundTags["tags"].ToList(); SortedDictionary sortedQualifyingPkgs = GetPackagesWithRequiredVersion(allVersionsList, versionType, versionRange, requiredVersion, packageNameForFind, includePrerelease, out errRecord); if (errRecord != null && sortedQualifyingPkgs?.Count == 0) @@ -1825,7 +821,7 @@ private Hashtable[] FindPackagesWithVersionHelper(string packageName, VersionTyp foreach (var pkgVersionTag in pkgsInDescendingOrder) { string exactTagVersion = pkgVersionTag.Value.ToString(); - Hashtable metadata = GetContainerRegistryMetadata(packageNameForFind, exactTagVersion, containerRegistryAccessToken, out errRecord); + Hashtable metadata = GetContainerRegistryMetadata(packageNameForFind, exactTagVersion, out errRecord); if (errRecord != null || metadata.Count == 0) { return emptyHashResponses; @@ -1845,17 +841,15 @@ private Hashtable[] FindPackagesWithVersionHelper(string packageName, VersionTyp /// /// Helper method used for find scenarios that resolves versions required from all versions found. /// - private SortedDictionary GetPackagesWithRequiredVersion(List allPkgVersions, VersionType versionType, VersionRange versionRange, NuGetVersion specificVersion, string packageName, bool includePrerelease, out ErrorRecord errRecord) + private SortedDictionary GetPackagesWithRequiredVersion(List allPkgVersions, VersionType versionType, VersionRange versionRange, NuGetVersion specificVersion, string packageName, bool includePrerelease, out ErrorRecord errRecord) { _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::GetPackagesWithRequiredVersion()"); errRecord = null; - // we need NuGetVersion to sort versions by order, and string pkgVersionString (which is the exact tag from the server) to call GetContainerRegistryMetadata() later with exact version tag. SortedDictionary sortedPkgs = new SortedDictionary(VersionComparer.Default); bool isSpecificVersionSearch = versionType == VersionType.SpecificVersion; - foreach (var pkgVersionTagInfo in allPkgVersions) + foreach (var pkgVersionString in allPkgVersions) { - string pkgVersionString = pkgVersionTagInfo.ToString(); // determine if the package version that is a repository tag is a valid NuGetVersion if (!NuGetVersion.TryParse(pkgVersionString, out NuGetVersion pkgVersion)) { @@ -1910,13 +904,8 @@ private FindResults FindPackages(string packageName, bool includePrerelease, out { _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::FindPackages()"); errRecord = null; - string containerRegistryAccessToken = GetContainerRegistryAccessToken(needCatalogAccess: true, isPushOperation: false, out errRecord); - if (errRecord != null) - { - return emptyResponseResults; - } - var pkgResult = FindAllRepositories(containerRegistryAccessToken, out errRecord); + var repositoryNames = ListAllRepositories(out errRecord); if (errRecord != null) { return emptyResponseResults; @@ -1926,10 +915,8 @@ private FindResults FindPackages(string packageName, bool includePrerelease, out var isMAR = Repository.IsMARRepository(); // Convert the list of repositories to a list of hashtables - foreach (var repository in pkgResult["repositories"].ToList()) + foreach (var repositoryName in repositoryNames) { - string repositoryName = repository.ToString(); - if (isMAR && !repositoryName.StartsWith(PSRepositoryInfo.MARPrefix)) { continue; diff --git a/src/code/FindHelper.cs b/src/code/FindHelper.cs index 388be9090..e27676de4 100644 --- a/src/code/FindHelper.cs +++ b/src/code/FindHelper.cs @@ -9,6 +9,7 @@ using System.Management.Automation; using System.Net; using System.Runtime.ExceptionServices; +using System.Runtime.InteropServices; using System.Text.RegularExpressions; using System.Threading; @@ -188,7 +189,11 @@ public IEnumerable FindByResourceName( { PSRepositoryInfo currentRepository = repositoriesToSearch[i]; - bool isAllowed = GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(currentRepository.Uri); + bool isAllowed = true; + if (System.Runtime.InteropServices.RuntimeInformation.IsOSPlatform(System.Runtime.InteropServices.OSPlatform.Windows)) + { + isAllowed = GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(currentRepository.Uri); + } if (!isAllowed) { @@ -376,7 +381,12 @@ public IEnumerable FindByCommandOrDscResource( { PSRepositoryInfo currentRepository = repositoriesToSearch[i]; - bool isAllowed = GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(currentRepository.Uri); + bool isAllowed = true; + + if (OperatingSystem.IsWindows()) + { + isAllowed = GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(currentRepository.Uri); + } if (!isAllowed) { @@ -583,7 +593,12 @@ public IEnumerable FindByTag( { PSRepositoryInfo currentRepository = repositoriesToSearch[i]; - bool isAllowed = GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(currentRepository.Uri); + + bool isAllowed = true; + if (OperatingSystem.IsWindows()) + { + isAllowed = GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(currentRepository.Uri); + } if (!isAllowed) { @@ -701,7 +716,7 @@ private IEnumerable SearchByNames(ServerApiCall currentServer, R _cmdletPassedIn.WriteDebug("No version specified, package name is '*'"); // Example: Find-PSResource -Name "*" - // Note: Just for resources from V2 servers, specifically PSGallery, if the resource is unlisted and was requested non-explicitly + // Note: Just for resources from V2 servers, specifically PSGallery, if the resource is unlisted and was requested non-explicitly // (i.e requested name has wildcard) the resource should not be returned and ResponseUtil.ConvertToPSResourceResult() call needs to be informed of this. // In all other cases, return the resource regardless of whether it was requested explicitly or not. bool isResourceRequestedWithWildcard = isV2Resource; @@ -752,7 +767,7 @@ private IEnumerable SearchByNames(ServerApiCall currentServer, R // Example: Find-PSResource -Name "Az*" -Tag "Storage" _cmdletPassedIn.WriteDebug("No version specified, package name contains a wildcard."); - // Note: Just for resources from V2 servers, specifically PSGallery, if the resource is unlisted and was requested non-explicitly + // Note: Just for resources from V2 servers, specifically PSGallery, if the resource is unlisted and was requested non-explicitly // (i.e requested name has wildcard) the resource should not be returned and ResponseUtil.ConvertToPSResourceResult() call needs to be informed of this. // In all other cases, return the resource regardless of whether it was requested explicitly or not. bool isResourceRequestedWithWildcard = isV2Resource; @@ -1173,7 +1188,7 @@ internal IEnumerable FindDependencyPackages( } else if(dep.VersionRange.MaxVersion != null && dep.VersionRange.MinVersion != null && dep.VersionRange.MaxVersion.OriginalVersion.Equals(dep.VersionRange.MinVersion.OriginalVersion)) { - string depPkgVersion = dep.VersionRange.MaxVersion.OriginalVersion; + string depPkgVersion = dep.VersionRange.MaxVersion.OriginalVersion; FindResults responses = currentServer.FindVersion(dep.Name, version: dep.VersionRange.MaxVersion.ToNormalizedString(), _type, out ErrorRecord errRecord); if (errRecord != null) { diff --git a/src/code/GroupPolicyRepositoryEnforcement.cs b/src/code/GroupPolicyRepositoryEnforcement.cs index ac4f7ee98..8d0be2f64 100644 --- a/src/code/GroupPolicyRepositoryEnforcement.cs +++ b/src/code/GroupPolicyRepositoryEnforcement.cs @@ -6,6 +6,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.Versioning; using Microsoft.PowerShell.PSResourceGet.UtilClasses; using Microsoft.Win32; @@ -14,6 +15,7 @@ namespace Microsoft.PowerShell.PSResourceGet.Cmdlets /// /// This class is used to enforce group policy for repositories. /// + [SupportedOSPlatform("windows")] public class GroupPolicyRepositoryEnforcement { const string userRoot = "HKEY_CURRENT_USER"; @@ -29,6 +31,7 @@ private GroupPolicyRepositoryEnforcement() /// /// /// True if the group policy is enabled, false otherwise. + [SupportedOSPlatform("windows")] public static bool IsGroupPolicyEnabled() { if (Environment.OSVersion.Platform != PlatformID.Win32NT) @@ -57,6 +60,7 @@ public static bool IsGroupPolicyEnabled() /// /// Array of allowed URIs. /// Thrown when the group policy is not enabled. + [SupportedOSPlatform("windows")] public static Uri[]? GetAllowedRepositoryURIs() { if (Environment.OSVersion.Platform != PlatformID.Win32NT) @@ -92,6 +96,7 @@ public static bool IsGroupPolicyEnabled() } } + [SupportedOSPlatform("windows")] internal static bool IsRepositoryAllowed(Uri repositoryUri) { bool isAllowed = false; @@ -113,6 +118,7 @@ internal static bool IsRepositoryAllowed(Uri repositoryUri) return isAllowed; } + [SupportedOSPlatform("windows")] private static List>? ReadGPFromRegistry() { List> allowedRepositories = new List>(); @@ -169,7 +175,13 @@ internal static bool IsRepositoryAllowed(Uri repositoryUri) throw new InvalidOperationException("Invalid registry value."); } - string valueString = value.ToString(); + string? valueString = value.ToString(); + + if (string.IsNullOrEmpty(valueString)) + { + throw new InvalidOperationException("Invalid registry value."); + } + var kvRegValue = ConvertRegValue(valueString); allowedRepositories.Add(kvRegValue); } diff --git a/src/code/InstallHelper.cs b/src/code/InstallHelper.cs index 0616cf040..003e48df8 100644 --- a/src/code/InstallHelper.cs +++ b/src/code/InstallHelper.cs @@ -270,7 +270,12 @@ private List ProcessRepositories( { PSRepositoryInfo currentRepository = listOfRepositories[i]; - bool isAllowed = GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(currentRepository.Uri); + bool isAllowed = true; + + if (OperatingSystem.IsWindows()) + { + isAllowed = GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(currentRepository.Uri); + } if (!isAllowed) { @@ -659,7 +664,7 @@ private List InstallPackages( ErrorCategory.InvalidOperation, _cmdletPassedIn)); - throw e; + throw; } finally { diff --git a/src/code/Microsoft.PowerShell.PSResourceGet.csproj b/src/code/Microsoft.PowerShell.PSResourceGet.csproj index daeaff8e1..7f240c76f 100644 --- a/src/code/Microsoft.PowerShell.PSResourceGet.csproj +++ b/src/code/Microsoft.PowerShell.PSResourceGet.csproj @@ -8,8 +8,8 @@ 1.1.0.1 1.1.0.1 1.1.0.1 - net472 - 9.0 + net8.0 + 12.0 true @@ -21,12 +21,12 @@ - + + - diff --git a/src/code/PSResourceGetCredentialProvider.cs b/src/code/PSResourceGetCredentialProvider.cs new file mode 100644 index 000000000..af5ce5912 --- /dev/null +++ b/src/code/PSResourceGetCredentialProvider.cs @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Management.Automation; +using System.Management.Automation.Runspaces; +using System.Net.Http; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.PowerShell.PSResourceGet.UtilClasses; +using OrasProject.Oras.Registry.Remote.Auth; + +namespace Microsoft.PowerShell.PSResourceGet +{ + /// + /// Implements the ORAS ICredentialProvider interface for PSResourceGet. + /// Handles three authentication pathways: + /// 1. Credentials from SecretManagement vault (CredentialInfo provided) + /// 2. Azure Identity via Utils.GetAzAccessToken (existing helper) + /// 3. Anonymous/unauthenticated access + /// + internal class PSResourceGetCredentialProvider : ICredentialProvider + { + private readonly PSRepositoryInfo _repository; + private readonly PSCmdlet _cmdletPassedIn; + private readonly Runspace _callerRunspace; + private readonly string _registryHost; + private readonly HttpClient _httpClient; + private Credential _cachedCredential; + private DateTimeOffset _tokenExpiry = DateTimeOffset.MinValue; + + // Template for the ACR OAuth2 exchange endpoint + private const string OAuthExchangeUrlTemplate = "https://{0}/oauth2/exchange"; + private const string RefreshTokenRequestBodyTemplate = "grant_type=access_token&service={0}&tenant={1}&access_token={2}"; + private const string RefreshTokenRequestBodyNoTenantTemplate = "grant_type=access_token&service={0}&access_token={1}"; + + internal PSResourceGetCredentialProvider(PSRepositoryInfo repository, PSCmdlet cmdletPassedIn, HttpClient httpClient = null) + { + _repository = repository; + _cmdletPassedIn = cmdletPassedIn; + _callerRunspace = Runspace.DefaultRunspace; + _registryHost = repository.Uri.Host; + _httpClient = httpClient ?? new HttpClient(); + _cachedCredential = new Credential(); + } + + public async Task ResolveCredentialAsync(string hostname, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(hostname)) + { + throw new ArgumentException("Hostname cannot be null or empty.", nameof(hostname)); + } + + // ORAS invokes this callback on a thread pool thread which has no + // PowerShell Runspace. Restore the caller's Runspace so that + // InvokeCommand.InvokeScript, WriteVerbose, WriteWarning and any + // nested PowerShell script invocations (SecretManagement, etc.) work. + var previousRunspace = Runspace.DefaultRunspace; + Runspace.DefaultRunspace = _callerRunspace; + + try + { + return await ResolveCredentialCoreAsync(hostname, cancellationToken).ConfigureAwait(false); + } + finally + { + Runspace.DefaultRunspace = previousRunspace; + } + } + + private async Task ResolveCredentialCoreAsync(string hostname, CancellationToken cancellationToken) + { + // Return cached credential if still valid + if (!string.IsNullOrEmpty(_cachedCredential.RefreshToken) && DateTimeOffset.UtcNow < _tokenExpiry) + { + Utils.WriteVerboseOnCmdlet(_cmdletPassedIn, "Using cached ORAS credential."); + return _cachedCredential; + } + + string aadAccessToken; + string tenantId; + + var repositoryCredentialInfo = _repository.CredentialInfo; + if (repositoryCredentialInfo != null) + { + // Path 1: Credential from SecretsManagement vault + Utils.WriteVerboseOnCmdlet(_cmdletPassedIn, "Retrieving access token from SecretManagement vault."); + aadAccessToken = Utils.GetContainerRegistryAccessTokenFromSecretManagement( + _repository.Name, + repositoryCredentialInfo, + _cmdletPassedIn); + + if (string.IsNullOrEmpty(aadAccessToken)) + { + Utils.WriteWarningOnCmdlet(_cmdletPassedIn, "Failed to retrieve access token from SecretManagement vault."); + return new Credential(); + } + + tenantId = repositoryCredentialInfo.SecretName; + } + else + { + // Path 2: Azure Identity via existing Utils helper + Utils.WriteVerboseOnCmdlet(_cmdletPassedIn, "Acquiring AAD access token via Utils.GetAzAccessToken."); + aadAccessToken = Utils.GetAzAccessToken(_cmdletPassedIn); + + if (string.IsNullOrEmpty(aadAccessToken)) + { + // If Azure Identity fails, return empty credential for anonymous access + Utils.WriteVerboseOnCmdlet(_cmdletPassedIn, "No AAD token available; attempting anonymous access."); + return new Credential(); + } + + tenantId = null; + } + + // Exchange AAD access token for ACR refresh token via OAuth2 exchange endpoint + Utils.WriteVerboseOnCmdlet(_cmdletPassedIn, "Exchanging AAD access token for ACR refresh token."); + try + { + string refreshToken = await ExchangeForAcrRefreshTokenAsync(aadAccessToken, tenantId, cancellationToken).ConfigureAwait(false); + + if (string.IsNullOrEmpty(refreshToken)) + { + Utils.WriteWarningOnCmdlet(_cmdletPassedIn, "Failed to obtain ACR refresh token from exchange."); + return new Credential(); + } + + _cachedCredential = new Credential(RefreshToken: refreshToken); + _tokenExpiry = DateTimeOffset.UtcNow.AddMinutes(55); // ACR tokens typically valid for ~60 min + return _cachedCredential; + } + catch (Exception ex) + { + Utils.WriteWarningOnCmdlet(_cmdletPassedIn, $"Failed to exchange AAD token for ACR refresh token: {ex.Message}"); + return new Credential(); + } + } + + /// + /// Exchanges an AAD access token for an ACR refresh token via the OAuth2 exchange endpoint. + /// + private async Task ExchangeForAcrRefreshTokenAsync(string aadAccessToken, string tenantId, CancellationToken cancellationToken) + { + string exchangeUrl = string.Format(OAuthExchangeUrlTemplate, _registryHost); + string requestBody = string.IsNullOrEmpty(tenantId) + ? string.Format(RefreshTokenRequestBodyNoTenantTemplate, _registryHost, aadAccessToken) + : string.Format(RefreshTokenRequestBodyTemplate, _registryHost, tenantId, aadAccessToken); + + using var content = new StringContent(requestBody, System.Text.Encoding.UTF8, "application/x-www-form-urlencoded"); + using var response = await _httpClient.PostAsync(exchangeUrl, content, cancellationToken).ConfigureAwait(false); + response.EnsureSuccessStatusCode(); + + string responseBody = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + using var jsonDoc = JsonDocument.Parse(responseBody); + + if (jsonDoc.RootElement.TryGetProperty("refresh_token", out JsonElement refreshTokenElement)) + { + return refreshTokenElement.GetString(); + } + + return null; + } + } +} diff --git a/src/code/PSScriptMetadata.cs b/src/code/PSScriptMetadata.cs index a9aa9995e..597eead3e 100644 --- a/src/code/PSScriptMetadata.cs +++ b/src/code/PSScriptMetadata.cs @@ -115,7 +115,7 @@ public PSScriptMetadata( } Version = !String.IsNullOrEmpty(version) ? new NuGetVersion (version) : new NuGetVersion("1.0.0.0"); - Guid = (guid == null || guid == Guid.Empty) ? Guid.NewGuid() : guid; + Guid = guid; Author = !String.IsNullOrEmpty(author) ? author : Environment.UserName; CompanyName = companyName; Copyright = copyright; diff --git a/src/code/PublishHelper.cs b/src/code/PublishHelper.cs index abdced37b..ae2633196 100644 --- a/src/code/PublishHelper.cs +++ b/src/code/PublishHelper.cs @@ -355,7 +355,12 @@ internal void PushResource(string Repository, string modulePrefix, bool SkipDepe return; } - bool isAllowed = GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(repository.Uri); + bool isAllowed = true; + + if (OperatingSystem.IsWindows()) + { + isAllowed = GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(repository.Uri); + } if (!isAllowed) { diff --git a/src/code/RepositorySettings.cs b/src/code/RepositorySettings.cs index 7cf4f9261..957ace5c2 100644 --- a/src/code/RepositorySettings.cs +++ b/src/code/RepositorySettings.cs @@ -10,6 +10,7 @@ using System.Xml; using System.Xml.Linq; using Microsoft.PowerShell.PSResourceGet.Cmdlets; +using NuGet.Protocol.Core.Types; using static Microsoft.PowerShell.PSResourceGet.UtilClasses.PSRepositoryInfo; namespace Microsoft.PowerShell.PSResourceGet.UtilClasses @@ -288,7 +289,12 @@ public static PSRepositoryInfo Add(string repoName, Uri repoUri, int repoPriorit throw new PSInvalidOperationException(String.Format("Adding to repository store failed: {0}", e.Message)); } - bool isAllowed = GroupPolicyRepositoryEnforcement.IsGroupPolicyEnabled() ? GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(repoUri) : true; + bool isAllowed = true; + + if (OperatingSystem.IsWindows()) + { + isAllowed = GroupPolicyRepositoryEnforcement.IsGroupPolicyEnabled() ? GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(repoUri) : true; + } return new PSRepositoryInfo(repoName, repoUri, repoPriority, repoTrusted, repoCredentialInfo, repoCredentialProvider, apiVersion, isAllowed); } @@ -447,10 +453,13 @@ public static PSRepositoryInfo Update(string repoName, Uri repoUri, int repoPrio node.Attribute(PSCredentialInfo.SecretNameAttribute).Value); } - if (GroupPolicyRepositoryEnforcement.IsGroupPolicyEnabled()) + if (OperatingSystem.IsWindows()) { - var allowedList = GroupPolicyRepositoryEnforcement.GetAllowedRepositoryURIs(); + if (GroupPolicyRepositoryEnforcement.IsGroupPolicyEnabled()) + { + var allowedList = GroupPolicyRepositoryEnforcement.GetAllowedRepositoryURIs(); + } } // Update CredentialProvider if necessary @@ -468,7 +477,12 @@ public static PSRepositoryInfo Update(string repoName, Uri repoUri, int repoPrio } } - bool isAllowed = GroupPolicyRepositoryEnforcement.IsGroupPolicyEnabled() ? GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(thisUrl) : true; + bool isAllowed = true; + + if (OperatingSystem.IsWindows()) + { + isAllowed = GroupPolicyRepositoryEnforcement.IsGroupPolicyEnabled() ? GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(thisUrl) : true; + } updatedRepo = new PSRepositoryInfo(repoName, thisUrl, @@ -564,7 +578,12 @@ public static List Remove(string[] repoNames, out string[] err string attributeUrlUriName = urlAttributeExists ? "Url" : "Uri"; Uri repoUri = new Uri(node.Attribute(attributeUrlUriName).Value); - bool isAllowed = GroupPolicyRepositoryEnforcement.IsGroupPolicyEnabled() ? GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(repoUri) : true; + bool isAllowed = true; + + if (OperatingSystem.IsWindows()) + { + isAllowed = GroupPolicyRepositoryEnforcement.IsGroupPolicyEnabled() ? GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(repoUri) : true; + } removedRepos.Add( new PSRepositoryInfo(repo, @@ -704,7 +723,11 @@ public static List Read(string[] repoNames, out string[] error continue; } - bool isAllowed = GroupPolicyRepositoryEnforcement.IsGroupPolicyEnabled() ? GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(thisUrl) : true; + bool isAllowed = true; + if (OperatingSystem.IsWindows()) + { + isAllowed = GroupPolicyRepositoryEnforcement.IsGroupPolicyEnabled() ? GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(thisUrl) : true; + } PSRepositoryInfo currentRepoItem = new PSRepositoryInfo(repo.Attribute("Name").Value, thisUrl, @@ -817,7 +840,12 @@ public static List Read(string[] repoNames, out string[] error continue; } - bool isAllowed = GroupPolicyRepositoryEnforcement.IsGroupPolicyEnabled() ? GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(thisUrl) : true; + bool isAllowed = true; + + if (OperatingSystem.IsWindows()) + { + isAllowed = GroupPolicyRepositoryEnforcement.IsGroupPolicyEnabled() ? GroupPolicyRepositoryEnforcement.IsRepositoryAllowed(thisUrl) : true; + } PSRepositoryInfo currentRepoItem = new PSRepositoryInfo(node.Attribute("Name").Value, thisUrl, diff --git a/src/code/UpdateModuleManifest.cs b/src/code/UpdateModuleManifest.cs index aa8d38e88..3c9093aec 100644 --- a/src/code/UpdateModuleManifest.cs +++ b/src/code/UpdateModuleManifest.cs @@ -604,7 +604,7 @@ private void CreateModuleManifestHelper(Hashtable parsedMetadata, string resolve parsedMetadata["Prerelease"] = Prerelease; } - if (RequireLicenseAcceptance != null && RequireLicenseAcceptance.IsPresent) + if (RequireLicenseAcceptance.IsPresent) { parsedMetadata["RequireLicenseAcceptance"] = RequireLicenseAcceptance; } @@ -953,7 +953,7 @@ private void CreateModuleManifestForWinPSHelper(Hashtable parsedMetadata, string prerelease = Prerelease; } - if (RequireLicenseAcceptance != null && RequireLicenseAcceptance.IsPresent) + if (RequireLicenseAcceptance.IsPresent) { requireLicenseAcceptance = RequireLicenseAcceptance; } diff --git a/src/code/Utils.cs b/src/code/Utils.cs index 26d3ab25e..b2c6bb481 100644 --- a/src/code/Utils.cs +++ b/src/code/Utils.cs @@ -1426,7 +1426,7 @@ public static bool ValidateModuleManifest(string moduleManifestPath, out string return false; } } - + // Check for any errors from Test-ModuleManifest if (pwsh.HadErrors) { @@ -1459,6 +1459,22 @@ public static void WriteVerboseOnCmdlet( catch { } } + public static void WriteWarningOnCmdlet( + PSCmdlet cmdlet, + string message) + { + try + { + cmdlet.InvokeCommand.InvokeScript( + script: $"param ([string] $message) Write-Warning -Message $message", + useNewScope: true, + writeToPipeline: System.Management.Automation.Runspaces.PipelineResultTypes.None, + input: null, + args: new object[] { message }); + } + catch { } + } + /// /// Convert a json string into a hashtable object. /// This uses custom script to perform the PSObject -> Hashtable @@ -1653,7 +1669,9 @@ public static void DeleteDirectoryWithRestore(string dirPath) } catch (Exception e) { - throw e; + throw new PSInvalidOperationException( + $"An error occurred while attempting to delete the directory at path {dirPath} with restore. Error: {e.Message}", + e); } finally { @@ -1684,7 +1702,7 @@ public static void DeleteDirectory(string dirPath) { if (!Directory.Exists(dirPath)) { - throw new Exception($"Path '{dirPath}' that was attempting to be deleted does not exist."); + throw new PSInvalidOperationException($"Path '{dirPath}' that was attempting to be deleted does not exist."); } // Remove read only file attributes first @@ -1721,10 +1739,10 @@ public static void DeleteDirectory(string dirPath) if (ex.Message.Contains("The directory is not empty") && psVersion.StartsWith("5")) { // there is a known bug with WindowsPowerShell and OneDrive based module paths, where .NET Directory.Delete() will throw a 'The directory is not empty.' error. - throw new Exception(string.Format("Cannot uninstall module with OneDrive based path on Windows PowerShell due to .NET issue. Try installing and uninstalling using PowerShell 7+ if using OneDrive."), ex); + throw new PSInvalidOperationException("Cannot uninstall module with OneDrive based path on Windows PowerShell due to .NET issue. Try installing and uninstalling using PowerShell 7+ if using OneDrive.", ex); } - throw new Exception(string.Format("Access denied to path while deleting path {0}", dirPath), ex); + throw new PSInvalidOperationException(string.Format("Access denied to path while deleting path {0}", dirPath), ex); } else {