@@ -164,6 +164,8 @@ const (
164164 defaultallowedOutOfOrder = 10
165165 defaultSpoolingDownloadWorkers = 5
166166 defaulttrinoEncoding = "json"
167+ defaultSourceName = "trino-go-client"
168+ defaultKerberosServiceName = "trino"
167169)
168170
169171var (
@@ -195,7 +197,7 @@ type Config struct {
195197 ExtraCredentials map [string ]string // Extra credentials (optional)
196198 ClientTags []string // A comma-separated list of “tag” strings, used to identify Trino resource groups (optional)
197199 CustomClientName string // Custom client name (optional)
198- KerberosEnabled string // KerberosEnabled (optional, default is false)
200+ KerberosEnabled bool // KerberosEnabled (optional, default is false)
199201 KerberosKeytabPath string // Kerberos Keytab Path (optional)
200202 KerberosPrincipal string // Kerberos Principal used to authenticate to KDC (optional)
201203 KerberosRemoteServiceName string // Trino coordinator Kerberos service name (optional)
@@ -209,8 +211,140 @@ type Config struct {
209211 QueryTimeout * time.Duration // Configurable timeout for query (optional)
210212}
211213
212- // FormatDSN returns a DSN string from the configuration.
214+ func (c * Config ) applyDefaults () {
215+ if c .Source == "" {
216+ c .Source = defaultSourceName
217+ }
218+
219+ if c .KerberosRemoteServiceName == "" && c .KerberosEnabled {
220+ c .KerberosRemoteServiceName = defaultKerberosServiceName
221+ }
222+ }
223+
224+ func ParseDSN (dsn string ) (* Config , error ) {
225+ serverURL , err := url .Parse (dsn )
226+ if err != nil {
227+ return nil , fmt .Errorf ("invalid DSN: %w" , err )
228+ }
229+
230+ query := serverURL .Query ()
231+ config := & Config {}
232+
233+ serverURI := serverURL .Scheme + "://"
234+ if serverURL .User != nil {
235+ serverURI += serverURL .User .String () + "@"
236+ }
237+
238+ serverURI += serverURL .Host
239+
240+ config .ServerURI = serverURI
241+ config .Source = query .Get ("source" )
242+
243+ config .Catalog = query .Get ("catalog" )
244+ config .Schema = query .Get ("schema" )
245+
246+ if sessionProps := query .Get ("session_properties" ); sessionProps != "" {
247+ var err error
248+ config .SessionProperties , err = parseMapParameter (sessionProps , "session property" , mapEntrySeparator , mapKeySeparator )
249+ if err != nil {
250+ return nil , err
251+ }
252+ }
253+
254+ if extraCreds := query .Get ("extra_credentials" ); extraCreds != "" {
255+ var err error
256+ config .ExtraCredentials , err = parseMapParameter (extraCreds , "extra credential" , mapEntrySeparator , mapKeySeparator )
257+ if err != nil {
258+ return nil , err
259+ }
260+ }
261+
262+ if clientTags := query .Get ("clientTags" ); clientTags != "" {
263+ config .ClientTags = strings .Split (clientTags , commaSeparator )
264+ }
265+
266+ config .CustomClientName = query .Get ("custom_client" )
267+ config .AccessToken = query .Get (accessTokenConfig )
268+
269+ if explicitPrepare := query .Get (explicitPrepareConfig ); explicitPrepare != "" {
270+ explicitPrepareValue , err := strconv .ParseBool (explicitPrepare )
271+ if err != nil {
272+ return nil , fmt .Errorf ("invalid boolean for %s: %q" , explicitPrepareConfig , explicitPrepare )
273+ }
274+ config .DisableExplicitPrepare = ! explicitPrepareValue
275+ }
276+
277+ if forwardAuth := query .Get (forwardAuthorizationHeaderConfig ); forwardAuth != "" {
278+ forwardAuthValue , err := strconv .ParseBool (forwardAuth )
279+ if err != nil {
280+ return nil , fmt .Errorf ("invalid boolean for %s: %q" , forwardAuthorizationHeaderConfig , forwardAuth )
281+ }
282+ config .ForwardAuthorizationHeader = forwardAuthValue
283+ }
284+
285+ if queryTimeoutStr := query .Get ("query_timeout" ); queryTimeoutStr != "" {
286+ queryTimeout , err := time .ParseDuration (queryTimeoutStr )
287+ if err != nil {
288+ return nil , fmt .Errorf ("trino: invalid timeout for query_timeout: %q" , queryTimeoutStr )
289+ }
290+ config .QueryTimeout = & queryTimeout
291+ }
292+
293+ if kerberosParam := query .Get (kerberosEnabledConfig ); kerberosParam != "" {
294+ enabled , err := strconv .ParseBool (kerberosParam )
295+ if err != nil {
296+ return nil , fmt .Errorf ("invalid boolean for %s: %q" , kerberosEnabledConfig , kerberosParam )
297+ }
298+ config .KerberosEnabled = enabled
299+ }
300+
301+ if kp := query .Get (kerberosKeytabPathConfig ); kp != "" {
302+ config .KerberosKeytabPath = kp
303+ }
304+
305+ if p := query .Get (kerberosPrincipalConfig ); p != "" {
306+ config .KerberosPrincipal = p
307+ }
308+
309+ if r := query .Get (kerberosRealmConfig ); r != "" {
310+ config .KerberosRealm = r
311+ }
312+
313+ if kp := query .Get (kerberosConfigPathConfig ); kp != "" {
314+ config .KerberosConfigPath = kp
315+ }
316+
317+ if rsn := query .Get (kerberosRemoteServiceNameConfig ); rsn != "" {
318+ config .KerberosRemoteServiceName = rsn
319+ }
320+
321+ if sslCertPath := query .Get (sslCertPathConfig ); sslCertPath != "" {
322+ config .SSLCertPath = sslCertPath
323+ }
324+
325+ if sslCert := query .Get (sslCertConfig ); sslCert != "" {
326+ config .SSLCert = sslCert
327+ }
328+
329+ config .applyDefaults ()
330+ return config , nil
331+ }
332+
333+ func parseMapParameter (value , paramName , entrySeparator , keyValueSeparator string ) (map [string ]string , error ) {
334+ result := make (map [string ]string )
335+ for _ , entry := range strings .Split (value , entrySeparator ) {
336+ parts := strings .SplitN (entry , keyValueSeparator , 2 )
337+ if len (parts ) != 2 {
338+ return nil , fmt .Errorf ("invalid %s entry: %q" , paramName , entry )
339+ }
340+ result [parts [0 ]] = parts [1 ]
341+ }
342+ return result , nil
343+ }
344+
213345func (c * Config ) FormatDSN () (string , error ) {
346+ c .applyDefaults ()
347+
214348 serverURL , err := url .Parse (c .ServerURI )
215349 if err != nil {
216350 return "" , err
@@ -227,18 +361,14 @@ func (c *Config) FormatDSN() (string, error) {
227361 credkv = append (credkv , k + mapKeySeparator + v )
228362 }
229363 }
230- source := c .Source
231- if source == "" {
232- source = "trino-go-client"
233- }
364+
234365 query := make (url.Values )
235- query .Add ("source" , source )
366+ query .Add ("source" , c . Source )
236367
237368 if c .ForwardAuthorizationHeader {
238369 query .Add (forwardAuthorizationHeaderConfig , "true" )
239370 }
240371
241- KerberosEnabled , _ := strconv .ParseBool (c .KerberosEnabled )
242372 isSSL := serverURL .Scheme == "https"
243373
244374 if c .DisableExplicitPrepare {
@@ -270,7 +400,7 @@ func (c *Config) FormatDSN() (string, error) {
270400 query .Add (sslCertConfig , c .SSLCert )
271401 }
272402
273- if KerberosEnabled {
403+ if c . KerberosEnabled {
274404 if ! isSSL {
275405 return "" , fmt .Errorf ("trino: client configuration error, SSL must be enabled for secure env" )
276406 }
@@ -279,11 +409,7 @@ func (c *Config) FormatDSN() (string, error) {
279409 query .Add (kerberosPrincipalConfig , c .KerberosPrincipal )
280410 query .Add (kerberosRealmConfig , c .KerberosRealm )
281411 query .Add (kerberosConfigPathConfig , c .KerberosConfigPath )
282- remoteServiceName := c .KerberosRemoteServiceName
283- if remoteServiceName == "" {
284- remoteServiceName = "trino"
285- }
286- query .Add (kerberosRemoteServiceNameConfig , remoteServiceName )
412+ query .Add (kerberosRemoteServiceNameConfig , c .KerberosRemoteServiceName )
287413 }
288414
289415 // ensure consistent order of items
@@ -333,52 +459,46 @@ var (
333459)
334460
335461func newConn (dsn string ) (* Conn , error ) {
336- serverURL , err := url . Parse (dsn )
462+ conf , err := ParseDSN (dsn )
337463 if err != nil {
338- return nil , fmt .Errorf ("trino: malformed dsn: %w" , err )
339- }
340-
341- query := serverURL .Query ()
342-
343- kerberosEnabled , _ := strconv .ParseBool (query .Get (kerberosEnabledConfig ))
344-
345- forwardAuthorizationHeader , _ := strconv .ParseBool (query .Get (forwardAuthorizationHeaderConfig ))
346-
347- useExplicitPrepare := true
348- if query .Get (explicitPrepareConfig ) != "" {
349- useExplicitPrepare , _ = strconv .ParseBool (query .Get (explicitPrepareConfig ))
464+ return nil , err
350465 }
351466
352467 var kerberosClient * client.Client
353468
354- if kerberosEnabled {
355- kt , err := keytab .Load (query . Get ( kerberosKeytabPathConfig ) )
469+ if conf . KerberosEnabled {
470+ kt , err := keytab .Load (conf . KerberosKeytabPath )
356471 if err != nil {
357472 return nil , fmt .Errorf ("trino: Error loading Keytab: %w" , err )
358473 }
359- conf , err := config .Load (query . Get ( kerberosConfigPathConfig ) )
474+ confKerb , err := config .Load (conf . KerberosConfigPath )
360475 if err != nil {
361476 return nil , fmt .Errorf ("trino: Error loading krb config: %w" , err )
362477 }
363478
364- kerberosClient = client .NewWithKeytab (query . Get ( kerberosPrincipalConfig ), query . Get ( kerberosRealmConfig ) , kt , conf )
479+ kerberosClient = client .NewWithKeytab (conf . KerberosPrincipal , conf . KerberosRealm , kt , confKerb )
365480 loginErr := kerberosClient .Login ()
366481 if loginErr != nil {
367482 return nil , fmt .Errorf ("trino: Error login to KDC: %v" , loginErr )
368483 }
369484 }
370485
486+ serverURL , err := url .Parse (conf .ServerURI )
487+ if err != nil {
488+ return nil , fmt .Errorf ("trino: invalid server URL: %w" , err )
489+ }
490+
371491 var httpClient = http .DefaultClient
372- if clientKey := query . Get ( "custom_client" ) ; clientKey != "" {
492+ if clientKey := conf . CustomClientName ; clientKey != "" {
373493 httpClient = getCustomClient (clientKey )
374494 if httpClient == nil {
375495 return nil , fmt .Errorf ("trino: custom client not registered: %q" , clientKey )
376496 }
377497 } else if serverURL .Scheme == "https" {
378498
379- cert := []byte (query . Get ( sslCertConfig ) )
499+ cert := []byte (conf . SSLCert )
380500
381- if certPath := query . Get ( sslCertPathConfig ) ; certPath != "" {
501+ if certPath := conf . SSLCertPath ; certPath != "" {
382502 cert , err = os .ReadFile (certPath )
383503 if err != nil {
384504 return nil , fmt .Errorf ("trino: Error loading SSL Cert File: %w" , err )
@@ -399,25 +519,16 @@ func newConn(dsn string) (*Conn, error) {
399519 }
400520 }
401521
402- var queryTimeout * time.Duration
403- if timeoutStr := query .Get ("query_timeout" ); timeoutStr != "" {
404- d , err := time .ParseDuration (timeoutStr )
405- if err != nil {
406- return nil , fmt .Errorf ("trino: invalid timeout: %w" , err )
407- }
408- queryTimeout = & d
409- }
410-
411522 c := & Conn {
412523 baseURL : serverURL .Scheme + "://" + serverURL .Host ,
413524 httpClient : * httpClient ,
414525 httpHeaders : make (http.Header ),
415526 kerberosClient : kerberosClient ,
416- kerberosEnabled : kerberosEnabled ,
417- kerberosRemoteServiceName : query . Get ( kerberosRemoteServiceNameConfig ) ,
418- useExplicitPrepare : useExplicitPrepare ,
419- forwardAuthorizationHeader : forwardAuthorizationHeader ,
420- queryTimeout : queryTimeout ,
527+ kerberosEnabled : conf . KerberosEnabled ,
528+ kerberosRemoteServiceName : conf . KerberosRemoteServiceName ,
529+ useExplicitPrepare : ! conf . DisableExplicitPrepare ,
530+ forwardAuthorizationHeader : conf . ForwardAuthorizationHeader ,
531+ queryTimeout : conf . QueryTimeout ,
421532 }
422533
423534 var user string
@@ -429,46 +540,42 @@ func newConn(dsn string) (*Conn, error) {
429540 }
430541 }
431542
432- if tags := query . Get ( "clientTags" ) ; tags != "" {
433- c .httpHeaders .Add (trinoTagsHeader , tags )
543+ if tags := conf . ClientTags ; tags != nil {
544+ c .httpHeaders .Add (trinoTagsHeader , strings . Join ( tags , commaSeparator ) )
434545 }
435546
436547 for k , v := range map [string ]string {
437548 trinoUserHeader : user ,
438- trinoSourceHeader : query . Get ( "source" ) ,
439- trinoCatalogHeader : query . Get ( "catalog" ) ,
440- trinoSchemaHeader : query . Get ( "schema" ) ,
441- authorizationHeader : getAuthorization (query . Get ( accessTokenConfig ) ),
549+ trinoSourceHeader : conf . Source ,
550+ trinoCatalogHeader : conf . Catalog ,
551+ trinoSchemaHeader : conf . Schema ,
552+ authorizationHeader : getAuthorization (conf . AccessToken ),
442553 } {
443554 if v != "" {
444555 c .httpHeaders .Add (k , v )
445556 }
446557 }
447- for header , param := range map [string ]string {
448- trinoSessionHeader : "session_properties" ,
449- trinoExtraCredentialHeader : "extra_credentials" ,
450- } {
451- v := query .Get (param )
452- if v != "" {
453- c .httpHeaders [header ], err = decodeMapHeader (param , v )
454- if err != nil {
455- return c , err
456- }
558+
559+ if conf .ExtraCredentials != nil {
560+ c .httpHeaders [trinoExtraCredentialHeader ], err = decodeMapHeader ("extra_credentials" , conf .ExtraCredentials )
561+ if err != nil {
562+ return c , err
563+ }
564+ }
565+
566+ if conf .SessionProperties != nil {
567+ c .httpHeaders [trinoSessionHeader ], err = decodeMapHeader ("session_properties" , conf .SessionProperties )
568+ if err != nil {
569+ return c , err
457570 }
458571 }
459572
460573 return c , nil
461574}
462575
463- func decodeMapHeader (name , input string ) ([]string , error ) {
464- result := []string {}
465- for _ , entry := range strings .Split (input , mapEntrySeparator ) {
466- parts := strings .SplitN (entry , mapKeySeparator , 2 )
467- if len (parts ) != 2 {
468- return nil , fmt .Errorf ("trino: Malformed %s: %s" , name , input )
469- }
470- key := parts [0 ]
471- value := parts [1 ]
576+ func decodeMapHeader (name string , m map [string ]string ) ([]string , error ) {
577+ result := make ([]string , 0 , len (m ))
578+ for key , value := range m {
472579 if len (key ) == 0 {
473580 return nil , fmt .Errorf ("trino: %s key is empty" , name )
474581 }
@@ -479,7 +586,6 @@ func decodeMapHeader(name, input string) ([]string, error) {
479586 return nil , fmt .Errorf ("trino: %s key '%s' contains spaces or is not printable ASCII" , name , key )
480587 }
481588 if ! isASCII (value ) {
482- // do not log value as it may contain sensitive information
483589 return nil , fmt .Errorf ("trino: %s value for key '%s' contains spaces or is not printable ASCII" , name , key )
484590 }
485591 result = append (result , key + "=" + url .QueryEscape (value ))
0 commit comments