diff --git a/src/FileTypeInterrogator/BaseFileTypeInterrogator.cs b/src/FileTypeInterrogator/BaseFileTypeInterrogator.cs index d5b21eb..98ad0c5 100644 --- a/src/FileTypeInterrogator/BaseFileTypeInterrogator.cs +++ b/src/FileTypeInterrogator/BaseFileTypeInterrogator.cs @@ -1,4 +1,5 @@ using System; +using System.Buffers; using System.Collections.Generic; using System.IO; using System.Linq; @@ -11,6 +12,9 @@ namespace FileTypeInterrogator /// public abstract class BaseFileTypeInterrogator : IFileTypeInterrogator { + private static readonly UTF8Encoding Utf8WithBomEncoding = new UTF8Encoding(true, true); + private static readonly UTF8Encoding Utf8WithoutBomEncoding = new UTF8Encoding(false, true); + private static readonly byte[] Utf8Bom = Utf8WithBomEncoding.GetPreamble(); private readonly Lazy> lazyFileTypes; private readonly FileTypeInfo asciiFileType = new FileTypeInfo("ASCII Text", "txt", "text/plain", null); private readonly FileTypeInfo utf8FileType = new FileTypeInfo("UTF-8 Text", "txt", "text/plain", null); @@ -43,13 +47,33 @@ public FileTypeInfo DetectType(Stream inputStream) if (inputStream.CanSeek) inputStream.Position = 0; - byte[] byteBuffer = new byte[inputStream.Length]; - _ = inputStream.Read(byteBuffer, 0, byteBuffer.Length); + long streamLength = inputStream.Length; + if (streamLength > int.MaxValue) + throw new NotSupportedException("Streams larger than 2 GB are not supported."); - if (inputStream.CanSeek) - inputStream.Position = 0; + int bufferSize = (int)streamLength; + byte[] byteBuffer = ArrayPool.Shared.Rent(bufferSize); + try + { + int bytesRead = 0; + while (bytesRead < bufferSize) + { + int read = inputStream.Read(byteBuffer, bytesRead, bufferSize - bytesRead); + if (read == 0) + break; + + bytesRead += read; + } + + return DetectType(byteBuffer, bytesRead); + } + finally + { + if (inputStream.CanSeek) + inputStream.Position = 0; - return DetectType(byteBuffer); + ArrayPool.Shared.Return(byteBuffer); + } } /// @@ -62,21 +86,31 @@ public FileTypeInfo DetectType(byte[] fileContent) if (fileContent == null) throw new ArgumentNullException(nameof(fileContent)); - if (fileContent.Length == 0) + return DetectType(fileContent, fileContent.Length); + } + + private FileTypeInfo DetectType(byte[] fileContent, int length) + { + if (fileContent == null) + throw new ArgumentNullException(nameof(fileContent)); + + if (length == 0) throw new ArgumentException("input must not be empty"); + ReadOnlySpan input = fileContent.AsSpan(0, length); + // iterate over each type and determine if we have a match based on file signature. foreach (var fileTypeInfo in AvailableTypes) { // if we found a match return the matching filetypeinfo - if (IsMatchingType(fileContent, fileTypeInfo)) + if (IsMatchingType(input, fileTypeInfo)) return fileTypeInfo; } - if (IsAscii(fileContent)) + if (IsAscii(input)) return asciiFileType; - if (IsUTF8(fileContent, out bool hasBOM)) + if (IsUTF8(fileContent, length, out bool hasBOM)) return hasBOM ? utf8FileTypeWithBOM : utf8FileType; return null; @@ -108,13 +142,23 @@ public IEnumerable GetAvailableMimeTypes() /// public bool IsType(byte[] fileContent, string extensionAliasOrMimeType) { - foreach (var fileTypeInfo in AvailableTypes.Where(t => - t.FileType.Equals(extensionAliasOrMimeType, StringComparison.OrdinalIgnoreCase) || - t.MimeType.Equals(extensionAliasOrMimeType, StringComparison.OrdinalIgnoreCase) || - (t.Alias != null && t.Alias.Contains(extensionAliasOrMimeType, StringComparer.OrdinalIgnoreCase)))) + if (fileContent == null) + throw new ArgumentNullException(nameof(fileContent)); + + return IsType(fileContent, fileContent.Length, extensionAliasOrMimeType); + } + + private bool IsType(byte[] fileContent, int length, string extensionAliasOrMimeType) + { + foreach (var fileTypeInfo in AvailableTypes) { - if (IsMatchingType(fileContent, fileTypeInfo)) - return true; + if (fileTypeInfo.FileType.Equals(extensionAliasOrMimeType, StringComparison.OrdinalIgnoreCase) || + fileTypeInfo.MimeType.Equals(extensionAliasOrMimeType, StringComparison.OrdinalIgnoreCase) || + (fileTypeInfo.Alias != null && fileTypeInfo.Alias.Contains(extensionAliasOrMimeType, StringComparer.OrdinalIgnoreCase))) + { + if (IsMatchingType(fileContent.AsSpan(0, length), fileTypeInfo)) + return true; + } } if (extensionAliasOrMimeType.Equals("txt", StringComparison.OrdinalIgnoreCase) || @@ -131,23 +175,18 @@ private static bool IsMatchingType(ReadOnlySpan input, FileTypeInfo type) // some file types have the same header // but different signature in another location, if its one of these determine what the true file type is - if (isMatch && type.SubHeader != null) + int subHeaderLength = type.SubHeader?.Length ?? 0; + if (isMatch && subHeaderLength > 0) { - // find all indices of matching the 1st byte of the additional sequence - var matchingIndices = new List(); - for (int i = 0; i < input.Length; i++) + isMatch = false; + for (int i = 0; i <= input.Length - subHeaderLength; i++) { if (input[i] == type.SubHeader[0]) - matchingIndices.Add(i); - } - - // investigate all of them for a match - foreach (int potentialMatchingIndex in matchingIndices) - { - isMatch = FindMatch(input, type.SubHeader, potentialMatchingIndex); - - if (isMatch) - break; + { + isMatch = FindMatch(input, type.SubHeader, i); + if (isMatch) + break; + } } } @@ -231,7 +270,7 @@ private static bool IsText(byte[] input, out bool hasBOM) bool isAscii = IsAscii(input); - return isAscii || IsUTF8(input, out hasBOM); + return isAscii || IsUTF8(input, input.Length, out hasBOM); } private static bool IsAscii(ReadOnlySpan input) @@ -245,20 +284,19 @@ private static bool IsAscii(ReadOnlySpan input) return true; } - private static bool IsUTF8(byte[] input, out bool hasBOM) + private static bool IsUTF8(byte[] input, int length, out bool hasBOM) { - UTF8Encoding utf8WithBOM = new UTF8Encoding(true, true); bool isUTF8 = true; - byte[] bom = utf8WithBOM.GetPreamble(); - int bomLength = bom.Length; + int bomLength = Utf8Bom.Length; hasBOM = false; - if (input.Length >= bomLength && bom.SequenceEqual(input.Take(bomLength))) + ReadOnlySpan inputSpan = input.AsSpan(0, length); + if (length >= bomLength && inputSpan.Slice(0, bomLength).SequenceEqual(Utf8Bom)) { try { - utf8WithBOM.GetString(input, bomLength, input.Length - bomLength); + Utf8WithBomEncoding.GetString(input, bomLength, length - bomLength); hasBOM = true; } catch (ArgumentException) @@ -270,10 +308,9 @@ private static bool IsUTF8(byte[] input, out bool hasBOM) if (isUTF8 && !hasBOM) { - UTF8Encoding utf8WithoutBOM = new UTF8Encoding(false, true); try { - utf8WithoutBOM.GetString(input, 0, input.Length); + Utf8WithoutBomEncoding.GetString(input, 0, length); isUTF8 = true; } catch (ArgumentException)