diff --git a/main.go b/main.go index 60e5004..f484670 100644 --- a/main.go +++ b/main.go @@ -48,6 +48,10 @@ func main() { var preferHTTPS bool flag.BoolVar(&preferHTTPS, "prefer-https", false, "only try plain HTTP if HTTPS fails") + // filter out cloudflare error pages + var filterCloudflareErrors bool + flag.BoolVar(&filterCloudflareErrors, "filter-cf-errors", false, "Filter out Cloudflare error pages") + // HTTP method to use var method string flag.StringVar(&method, "method", "GET", "HTTP method to use") @@ -57,6 +61,13 @@ func main() { // make an actual time.Duration out of the timeout timeout := time.Duration(to * 1000000) + var filterStrings []string + + // Add Cloudflare signatures to filterStrings if filterCloudflareErrors + if filterCloudflareErrors { + filterStrings = append(filterStrings, "
cloudflare
", "cf_styles-css") + } + var tr = &http.Transport{ MaxIdleConns: 30, IdleConnTimeout: time.Second, @@ -96,7 +107,7 @@ func main() { // always try HTTPS first withProto := "https://" + url - if isListening(client, withProto, method) { + if isListening(client, withProto, method, filterStrings) { output <- withProto // skip trying HTTP if --prefer-https is set @@ -120,7 +131,7 @@ func main() { go func() { for url := range httpURLs { withProto := "http://" + url - if isListening(client, withProto, method) { + if isListening(client, withProto, method, filterStrings) { output <- withProto continue } @@ -157,6 +168,11 @@ func main() { for sc.Scan() { domain := strings.ToLower(sc.Text()) + // Skip unresolvable domains + if _, err := net.LookupIP(domain); err != nil { + continue + } + // submit standard port checks if !skipDefault { httpsURLs <- domain @@ -210,7 +226,7 @@ func main() { outputWG.Wait() } -func isListening(client *http.Client, url, method string) bool { +func isListening(client *http.Client, url, method string, filterStrings []string) bool { req, err := http.NewRequest(method, url, nil) if err != nil { @@ -222,8 +238,21 @@ func isListening(client *http.Client, url, method string) bool { resp, err := client.Do(req) if resp != nil { - io.Copy(ioutil.Discard, resp.Body) - resp.Body.Close() + defer resp.Body.Close() + + if len(filterStrings) != 0 { + // Read the first 512 bytes of the response and check for presence of any filter strings + peek := make([]byte, 512) + resp.Body.Read(peek) + peekStr := string(peek) + for _, filterString := range filterStrings { + if strings.Contains(peekStr, filterString) { + return true + } + } + } else { + io.Copy(ioutil.Discard, resp.Body) + } } if err != nil {