diff --git a/probe/cri/registry.go b/probe/cri/registry.go index c346ec59b..c09cb81b3 100644 --- a/probe/cri/registry.go +++ b/probe/cri/registry.go @@ -18,26 +18,32 @@ func dial(addr string, timeout time.Duration) (net.Conn, error) { } func getAddressAndDialer(endpoint string) (string, func(addr string, timeout time.Duration) (net.Conn, error), error) { - protocol, addr, err := parseEndpointWithFallbackProtocol(endpoint, unixProtocol) + addr, err := parseEndpointWithFallbackProtocol(endpoint, unixProtocol) if err != nil { return "", nil, err } - if protocol != unixProtocol { - return "", nil, fmt.Errorf("endpoint was not unix socket %v", protocol) - } return addr, dial, nil } -func parseEndpointWithFallbackProtocol(endpoint string, fallbackProtocol string) (protocol string, addr string, err error) { - if protocol, addr, err = parseEndpoint(endpoint); err != nil && protocol == "" { +func parseEndpointWithFallbackProtocol(endpoint string, fallbackProtocol string) (addr string, err error) { + var protocol string + + protocol, addr, err = parseEndpoint(endpoint) + + if err != nil { + return "", err + } + + if protocol == "" { fallbackEndpoint := fallbackProtocol + "://" + endpoint - protocol, addr, err = parseEndpoint(fallbackEndpoint) + _, addr, err = parseEndpoint(fallbackEndpoint) + if err != nil { - return "", "", err + return "", err } } - return + return addr, err } func parseEndpoint(endpoint string) (string, string, error) { @@ -47,11 +53,11 @@ func parseEndpoint(endpoint string) (string, string, error) { } if u.Scheme == "tcp" { - return "tcp", u.Host, nil + return "tcp", u.Host, fmt.Errorf("endpoint was not unix socket %v", u.Scheme) } else if u.Scheme == "unix" { return "unix", u.Path, nil } else if u.Scheme == "" { - return "", "", fmt.Errorf("Using %q as endpoint is deprecated, please consider using full url format", endpoint) + return "", "", nil } else { return u.Scheme, "", fmt.Errorf("protocol %q not supported", u.Scheme) }