Skip to content

Commit d8a3792

Browse files
Flgadonineinchnick
authored andcommitted
Add DSN into Config struct
1 parent 66c0bd8 commit d8a3792

File tree

2 files changed

+328
-76
lines changed

2 files changed

+328
-76
lines changed

trino/trino.go

Lines changed: 180 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ const (
164164
defaultallowedOutOfOrder = 10
165165
defaultSpoolingDownloadWorkers = 5
166166
defaulttrinoEncoding = "json"
167+
defaultSourceName = "trino-go-client"
168+
defaultKerberosServiceName = "trino"
167169
)
168170

169171
var (
@@ -195,7 +197,7 @@ type Config struct {
195197
ExtraCredentials map[string]string // Extra credentials (optional)
196198
ClientTags []string // A comma-separated list of “tag” strings, used to identify Trino resource groups (optional)
197199
CustomClientName string // Custom client name (optional)
198-
KerberosEnabled string // KerberosEnabled (optional, default is false)
200+
KerberosEnabled bool // KerberosEnabled (optional, default is false)
199201
KerberosKeytabPath string // Kerberos Keytab Path (optional)
200202
KerberosPrincipal string // Kerberos Principal used to authenticate to KDC (optional)
201203
KerberosRemoteServiceName string // Trino coordinator Kerberos service name (optional)
@@ -209,8 +211,140 @@ type Config struct {
209211
QueryTimeout *time.Duration // Configurable timeout for query (optional)
210212
}
211213

212-
// FormatDSN returns a DSN string from the configuration.
214+
func (c *Config) applyDefaults() {
215+
if c.Source == "" {
216+
c.Source = defaultSourceName
217+
}
218+
219+
if c.KerberosRemoteServiceName == "" && c.KerberosEnabled {
220+
c.KerberosRemoteServiceName = defaultKerberosServiceName
221+
}
222+
}
223+
224+
func ParseDSN(dsn string) (*Config, error) {
225+
serverURL, err := url.Parse(dsn)
226+
if err != nil {
227+
return nil, fmt.Errorf("invalid DSN: %w", err)
228+
}
229+
230+
query := serverURL.Query()
231+
config := &Config{}
232+
233+
serverURI := serverURL.Scheme + "://"
234+
if serverURL.User != nil {
235+
serverURI += serverURL.User.String() + "@"
236+
}
237+
238+
serverURI += serverURL.Host
239+
240+
config.ServerURI = serverURI
241+
config.Source = query.Get("source")
242+
243+
config.Catalog = query.Get("catalog")
244+
config.Schema = query.Get("schema")
245+
246+
if sessionProps := query.Get("session_properties"); sessionProps != "" {
247+
var err error
248+
config.SessionProperties, err = parseMapParameter(sessionProps, "session property", mapEntrySeparator, mapKeySeparator)
249+
if err != nil {
250+
return nil, err
251+
}
252+
}
253+
254+
if extraCreds := query.Get("extra_credentials"); extraCreds != "" {
255+
var err error
256+
config.ExtraCredentials, err = parseMapParameter(extraCreds, "extra credential", mapEntrySeparator, mapKeySeparator)
257+
if err != nil {
258+
return nil, err
259+
}
260+
}
261+
262+
if clientTags := query.Get("clientTags"); clientTags != "" {
263+
config.ClientTags = strings.Split(clientTags, commaSeparator)
264+
}
265+
266+
config.CustomClientName = query.Get("custom_client")
267+
config.AccessToken = query.Get(accessTokenConfig)
268+
269+
if explicitPrepare := query.Get(explicitPrepareConfig); explicitPrepare != "" {
270+
explicitPrepareValue, err := strconv.ParseBool(explicitPrepare)
271+
if err != nil {
272+
return nil, fmt.Errorf("invalid boolean for %s: %q", explicitPrepareConfig, explicitPrepare)
273+
}
274+
config.DisableExplicitPrepare = !explicitPrepareValue
275+
}
276+
277+
if forwardAuth := query.Get(forwardAuthorizationHeaderConfig); forwardAuth != "" {
278+
forwardAuthValue, err := strconv.ParseBool(forwardAuth)
279+
if err != nil {
280+
return nil, fmt.Errorf("invalid boolean for %s: %q", forwardAuthorizationHeaderConfig, forwardAuth)
281+
}
282+
config.ForwardAuthorizationHeader = forwardAuthValue
283+
}
284+
285+
if queryTimeoutStr := query.Get("query_timeout"); queryTimeoutStr != "" {
286+
queryTimeout, err := time.ParseDuration(queryTimeoutStr)
287+
if err != nil {
288+
return nil, fmt.Errorf("trino: invalid timeout for query_timeout: %q", queryTimeoutStr)
289+
}
290+
config.QueryTimeout = &queryTimeout
291+
}
292+
293+
if kerberosParam := query.Get(kerberosEnabledConfig); kerberosParam != "" {
294+
enabled, err := strconv.ParseBool(kerberosParam)
295+
if err != nil {
296+
return nil, fmt.Errorf("invalid boolean for %s: %q", kerberosEnabledConfig, kerberosParam)
297+
}
298+
config.KerberosEnabled = enabled
299+
}
300+
301+
if kp := query.Get(kerberosKeytabPathConfig); kp != "" {
302+
config.KerberosKeytabPath = kp
303+
}
304+
305+
if p := query.Get(kerberosPrincipalConfig); p != "" {
306+
config.KerberosPrincipal = p
307+
}
308+
309+
if r := query.Get(kerberosRealmConfig); r != "" {
310+
config.KerberosRealm = r
311+
}
312+
313+
if kp := query.Get(kerberosConfigPathConfig); kp != "" {
314+
config.KerberosConfigPath = kp
315+
}
316+
317+
if rsn := query.Get(kerberosRemoteServiceNameConfig); rsn != "" {
318+
config.KerberosRemoteServiceName = rsn
319+
}
320+
321+
if sslCertPath := query.Get(sslCertPathConfig); sslCertPath != "" {
322+
config.SSLCertPath = sslCertPath
323+
}
324+
325+
if sslCert := query.Get(sslCertConfig); sslCert != "" {
326+
config.SSLCert = sslCert
327+
}
328+
329+
config.applyDefaults()
330+
return config, nil
331+
}
332+
333+
func parseMapParameter(value, paramName, entrySeparator, keyValueSeparator string) (map[string]string, error) {
334+
result := make(map[string]string)
335+
for _, entry := range strings.Split(value, entrySeparator) {
336+
parts := strings.SplitN(entry, keyValueSeparator, 2)
337+
if len(parts) != 2 {
338+
return nil, fmt.Errorf("invalid %s entry: %q", paramName, entry)
339+
}
340+
result[parts[0]] = parts[1]
341+
}
342+
return result, nil
343+
}
344+
213345
func (c *Config) FormatDSN() (string, error) {
346+
c.applyDefaults()
347+
214348
serverURL, err := url.Parse(c.ServerURI)
215349
if err != nil {
216350
return "", err
@@ -227,18 +361,14 @@ func (c *Config) FormatDSN() (string, error) {
227361
credkv = append(credkv, k+mapKeySeparator+v)
228362
}
229363
}
230-
source := c.Source
231-
if source == "" {
232-
source = "trino-go-client"
233-
}
364+
234365
query := make(url.Values)
235-
query.Add("source", source)
366+
query.Add("source", c.Source)
236367

237368
if c.ForwardAuthorizationHeader {
238369
query.Add(forwardAuthorizationHeaderConfig, "true")
239370
}
240371

241-
KerberosEnabled, _ := strconv.ParseBool(c.KerberosEnabled)
242372
isSSL := serverURL.Scheme == "https"
243373

244374
if c.DisableExplicitPrepare {
@@ -270,7 +400,7 @@ func (c *Config) FormatDSN() (string, error) {
270400
query.Add(sslCertConfig, c.SSLCert)
271401
}
272402

273-
if KerberosEnabled {
403+
if c.KerberosEnabled {
274404
if !isSSL {
275405
return "", fmt.Errorf("trino: client configuration error, SSL must be enabled for secure env")
276406
}
@@ -279,11 +409,7 @@ func (c *Config) FormatDSN() (string, error) {
279409
query.Add(kerberosPrincipalConfig, c.KerberosPrincipal)
280410
query.Add(kerberosRealmConfig, c.KerberosRealm)
281411
query.Add(kerberosConfigPathConfig, c.KerberosConfigPath)
282-
remoteServiceName := c.KerberosRemoteServiceName
283-
if remoteServiceName == "" {
284-
remoteServiceName = "trino"
285-
}
286-
query.Add(kerberosRemoteServiceNameConfig, remoteServiceName)
412+
query.Add(kerberosRemoteServiceNameConfig, c.KerberosRemoteServiceName)
287413
}
288414

289415
// ensure consistent order of items
@@ -333,52 +459,46 @@ var (
333459
)
334460

335461
func newConn(dsn string) (*Conn, error) {
336-
serverURL, err := url.Parse(dsn)
462+
conf, err := ParseDSN(dsn)
337463
if err != nil {
338-
return nil, fmt.Errorf("trino: malformed dsn: %w", err)
339-
}
340-
341-
query := serverURL.Query()
342-
343-
kerberosEnabled, _ := strconv.ParseBool(query.Get(kerberosEnabledConfig))
344-
345-
forwardAuthorizationHeader, _ := strconv.ParseBool(query.Get(forwardAuthorizationHeaderConfig))
346-
347-
useExplicitPrepare := true
348-
if query.Get(explicitPrepareConfig) != "" {
349-
useExplicitPrepare, _ = strconv.ParseBool(query.Get(explicitPrepareConfig))
464+
return nil, err
350465
}
351466

352467
var kerberosClient *client.Client
353468

354-
if kerberosEnabled {
355-
kt, err := keytab.Load(query.Get(kerberosKeytabPathConfig))
469+
if conf.KerberosEnabled {
470+
kt, err := keytab.Load(conf.KerberosKeytabPath)
356471
if err != nil {
357472
return nil, fmt.Errorf("trino: Error loading Keytab: %w", err)
358473
}
359-
conf, err := config.Load(query.Get(kerberosConfigPathConfig))
474+
confKerb, err := config.Load(conf.KerberosConfigPath)
360475
if err != nil {
361476
return nil, fmt.Errorf("trino: Error loading krb config: %w", err)
362477
}
363478

364-
kerberosClient = client.NewWithKeytab(query.Get(kerberosPrincipalConfig), query.Get(kerberosRealmConfig), kt, conf)
479+
kerberosClient = client.NewWithKeytab(conf.KerberosPrincipal, conf.KerberosRealm, kt, confKerb)
365480
loginErr := kerberosClient.Login()
366481
if loginErr != nil {
367482
return nil, fmt.Errorf("trino: Error login to KDC: %v", loginErr)
368483
}
369484
}
370485

486+
serverURL, err := url.Parse(conf.ServerURI)
487+
if err != nil {
488+
return nil, fmt.Errorf("trino: invalid server URL: %w", err)
489+
}
490+
371491
var httpClient = http.DefaultClient
372-
if clientKey := query.Get("custom_client"); clientKey != "" {
492+
if clientKey := conf.CustomClientName; clientKey != "" {
373493
httpClient = getCustomClient(clientKey)
374494
if httpClient == nil {
375495
return nil, fmt.Errorf("trino: custom client not registered: %q", clientKey)
376496
}
377497
} else if serverURL.Scheme == "https" {
378498

379-
cert := []byte(query.Get(sslCertConfig))
499+
cert := []byte(conf.SSLCert)
380500

381-
if certPath := query.Get(sslCertPathConfig); certPath != "" {
501+
if certPath := conf.SSLCertPath; certPath != "" {
382502
cert, err = os.ReadFile(certPath)
383503
if err != nil {
384504
return nil, fmt.Errorf("trino: Error loading SSL Cert File: %w", err)
@@ -399,25 +519,16 @@ func newConn(dsn string) (*Conn, error) {
399519
}
400520
}
401521

402-
var queryTimeout *time.Duration
403-
if timeoutStr := query.Get("query_timeout"); timeoutStr != "" {
404-
d, err := time.ParseDuration(timeoutStr)
405-
if err != nil {
406-
return nil, fmt.Errorf("trino: invalid timeout: %w", err)
407-
}
408-
queryTimeout = &d
409-
}
410-
411522
c := &Conn{
412523
baseURL: serverURL.Scheme + "://" + serverURL.Host,
413524
httpClient: *httpClient,
414525
httpHeaders: make(http.Header),
415526
kerberosClient: kerberosClient,
416-
kerberosEnabled: kerberosEnabled,
417-
kerberosRemoteServiceName: query.Get(kerberosRemoteServiceNameConfig),
418-
useExplicitPrepare: useExplicitPrepare,
419-
forwardAuthorizationHeader: forwardAuthorizationHeader,
420-
queryTimeout: queryTimeout,
527+
kerberosEnabled: conf.KerberosEnabled,
528+
kerberosRemoteServiceName: conf.KerberosRemoteServiceName,
529+
useExplicitPrepare: !conf.DisableExplicitPrepare,
530+
forwardAuthorizationHeader: conf.ForwardAuthorizationHeader,
531+
queryTimeout: conf.QueryTimeout,
421532
}
422533

423534
var user string
@@ -429,46 +540,42 @@ func newConn(dsn string) (*Conn, error) {
429540
}
430541
}
431542

432-
if tags := query.Get("clientTags"); tags != "" {
433-
c.httpHeaders.Add(trinoTagsHeader, tags)
543+
if tags := conf.ClientTags; tags != nil {
544+
c.httpHeaders.Add(trinoTagsHeader, strings.Join(tags, commaSeparator))
434545
}
435546

436547
for k, v := range map[string]string{
437548
trinoUserHeader: user,
438-
trinoSourceHeader: query.Get("source"),
439-
trinoCatalogHeader: query.Get("catalog"),
440-
trinoSchemaHeader: query.Get("schema"),
441-
authorizationHeader: getAuthorization(query.Get(accessTokenConfig)),
549+
trinoSourceHeader: conf.Source,
550+
trinoCatalogHeader: conf.Catalog,
551+
trinoSchemaHeader: conf.Schema,
552+
authorizationHeader: getAuthorization(conf.AccessToken),
442553
} {
443554
if v != "" {
444555
c.httpHeaders.Add(k, v)
445556
}
446557
}
447-
for header, param := range map[string]string{
448-
trinoSessionHeader: "session_properties",
449-
trinoExtraCredentialHeader: "extra_credentials",
450-
} {
451-
v := query.Get(param)
452-
if v != "" {
453-
c.httpHeaders[header], err = decodeMapHeader(param, v)
454-
if err != nil {
455-
return c, err
456-
}
558+
559+
if conf.ExtraCredentials != nil {
560+
c.httpHeaders[trinoExtraCredentialHeader], err = decodeMapHeader("extra_credentials", conf.ExtraCredentials)
561+
if err != nil {
562+
return c, err
563+
}
564+
}
565+
566+
if conf.SessionProperties != nil {
567+
c.httpHeaders[trinoSessionHeader], err = decodeMapHeader("session_properties", conf.SessionProperties)
568+
if err != nil {
569+
return c, err
457570
}
458571
}
459572

460573
return c, nil
461574
}
462575

463-
func decodeMapHeader(name, input string) ([]string, error) {
464-
result := []string{}
465-
for _, entry := range strings.Split(input, mapEntrySeparator) {
466-
parts := strings.SplitN(entry, mapKeySeparator, 2)
467-
if len(parts) != 2 {
468-
return nil, fmt.Errorf("trino: Malformed %s: %s", name, input)
469-
}
470-
key := parts[0]
471-
value := parts[1]
576+
func decodeMapHeader(name string, m map[string]string) ([]string, error) {
577+
result := make([]string, 0, len(m))
578+
for key, value := range m {
472579
if len(key) == 0 {
473580
return nil, fmt.Errorf("trino: %s key is empty", name)
474581
}
@@ -479,7 +586,6 @@ func decodeMapHeader(name, input string) ([]string, error) {
479586
return nil, fmt.Errorf("trino: %s key '%s' contains spaces or is not printable ASCII", name, key)
480587
}
481588
if !isASCII(value) {
482-
// do not log value as it may contain sensitive information
483589
return nil, fmt.Errorf("trino: %s value for key '%s' contains spaces or is not printable ASCII", name, key)
484590
}
485591
result = append(result, key+"="+url.QueryEscape(value))

0 commit comments

Comments
 (0)