Skip to content

Commit d07a4d8

Browse files
committed
Reduced logic to only fix ommited standard ports.
1 parent c447ac6 commit d07a4d8

File tree

1 file changed

+41
-30
lines changed

1 file changed

+41
-30
lines changed

src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,40 +18,42 @@
1818

1919

2020
def _extract_hostname(netloc: str) -> str:
21-
"""
22-
Extract hostname from netloc, ignoring port number.
23-
24-
Args:
25-
netloc: Network location string (e.g., "localhost:8080" or "example.com")
26-
27-
Returns:
28-
Hostname without port (e.g., "localhost" or "example.com")
29-
30-
"""
21+
"""Extract hostname from netloc."""
3122
if ":" in netloc:
3223
if netloc.startswith("["):
3324
# IPv6 with port: [::1]:8080
3425
end_bracket = netloc.rfind("]")
3526
if end_bracket != -1:
3627
return netloc[: end_bracket + 1]
37-
# Regular hostname with port: localhost:8080
3828
return netloc.split(":", 1)[0]
3929
return netloc
4030

4131

42-
def _hostnames_match(hostname1: str, hostname2: str) -> bool:
32+
def _netlocs_match(netloc1: str, scheme1: str, netloc2: str, scheme2: str) -> bool:
4333
"""
44-
Check if two hostnames match, ignoring case and port.
45-
46-
Args:
47-
hostname1: First hostname (may include port)
48-
hostname2: Second hostname (may include port)
49-
50-
Returns:
51-
True if hostnames match (case-insensitive, ignoring port)
52-
34+
Check if two netlocs match. Ports must match exactly, but missing ports
35+
are assumed to be standard ports (80 for http, 443 for https).
5336
"""
54-
return _extract_hostname(hostname1).lower() == _extract_hostname(hostname2).lower()
37+
if _extract_hostname(netloc1).lower() != _extract_hostname(netloc2).lower():
38+
return False
39+
40+
def _get_port(netloc: str, scheme: str) -> int:
41+
if ":" in netloc:
42+
if netloc.startswith("["):
43+
end_bracket = netloc.rfind("]")
44+
if end_bracket != -1 and end_bracket + 1 < len(netloc):
45+
try:
46+
return int(netloc[end_bracket + 2 :])
47+
except ValueError:
48+
pass
49+
else:
50+
try:
51+
return int(netloc.split(":", 1)[1])
52+
except ValueError:
53+
pass
54+
return 443 if scheme == "https" else 80
55+
56+
return _get_port(netloc1, scheme1) == _get_port(netloc2, scheme2)
5557

5658

5759
@dataclass
@@ -107,13 +109,19 @@ def _update_link(
107109

108110
parsed_link = urlparse(link["href"])
109111

110-
link_hostname = _extract_hostname(parsed_link.netloc)
111-
request_hostname = _extract_hostname(request_url.netloc)
112-
upstream_hostname = _extract_hostname(upstream_url.netloc)
113-
114112
if not (
115-
_hostnames_match(link_hostname, request_hostname)
116-
or _hostnames_match(link_hostname, upstream_hostname)
113+
_netlocs_match(
114+
parsed_link.netloc,
115+
parsed_link.scheme,
116+
request_url.netloc,
117+
request_url.scheme,
118+
)
119+
or _netlocs_match(
120+
parsed_link.netloc,
121+
parsed_link.scheme,
122+
upstream_url.netloc,
123+
upstream_url.scheme,
124+
)
117125
):
118126
logger.debug(
119127
"Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)",
@@ -135,8 +143,11 @@ def _update_link(
135143
return
136144

137145
# Replace the upstream host with the client's host
138-
link_matches_upstream = _hostnames_match(
139-
parsed_link.netloc, upstream_url.netloc
146+
link_matches_upstream = _netlocs_match(
147+
parsed_link.netloc,
148+
parsed_link.scheme,
149+
upstream_url.netloc,
150+
upstream_url.scheme,
140151
)
141152
parsed_link = parsed_link._replace(netloc=request_url.netloc)
142153
if link_matches_upstream:

0 commit comments

Comments
 (0)