Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions pkg/config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ var (
// defaultSessionTimeout default: 24 hour
defaultSessionTimeout = 24 * time.Hour

// defaultWaitTimeoutMin default: 1 second
defaultWaitTimeoutMin int64 = 1
// defaultWaitTimeoutMax default: 24 hours (in seconds)
defaultWaitTimeoutMax int64 = 24 * 60 * 60

// defaultInteractiveTimeoutMin default: 1 second
defaultInteractiveTimeoutMin int64 = 1
// defaultInteractiveTimeoutMax default: 24 hours (in seconds)
defaultInteractiveTimeoutMax int64 = 24 * 60 * 60

// defaultNetReadTimeout default: 0 (no timeout for normal operations)
defaultNetReadTimeout = time.Duration(0)

Expand Down Expand Up @@ -287,6 +297,15 @@ type FrontendParameters struct {
// NetWriteTimeout is the timeout for writing to the network connection. Default is 60 seconds.
NetWriteTimeout toml.Duration `toml:"netWriteTimeout"`

// WaitTimeoutMin is the hard minimum for wait_timeout (seconds).
WaitTimeoutMin int64 `toml:"waitTimeoutMin"`
// WaitTimeoutMax is the hard maximum for wait_timeout (seconds).
WaitTimeoutMax int64 `toml:"waitTimeoutMax"`
// InteractiveTimeoutMin is the hard minimum for interactive_timeout (seconds).
InteractiveTimeoutMin int64 `toml:"interactiveTimeoutMin"`
// InteractiveTimeoutMax is the hard maximum for interactive_timeout (seconds).
InteractiveTimeoutMax int64 `toml:"interactiveTimeoutMax"`

// LoadLocalReadTimeout is the timeout for reading data from client during LOAD DATA LOCAL operations.
// Used to detect F5/LoadBalancer idle timeout disconnections. Default is 60 seconds.
LoadLocalReadTimeout toml.Duration `toml:"loadLocalReadTimeout"`
Expand Down Expand Up @@ -443,6 +462,19 @@ func (fp *FrontendParameters) SetDefaultValues() {
fp.NetWriteTimeout.Duration = defaultNetWriteTimeout
}

if fp.WaitTimeoutMin == 0 {
fp.WaitTimeoutMin = defaultWaitTimeoutMin
}
if fp.WaitTimeoutMax == 0 {
fp.WaitTimeoutMax = defaultWaitTimeoutMax
}
if fp.InteractiveTimeoutMin == 0 {
fp.InteractiveTimeoutMin = defaultInteractiveTimeoutMin
}
if fp.InteractiveTimeoutMax == 0 {
fp.InteractiveTimeoutMax = defaultInteractiveTimeoutMax
}

if fp.LoadLocalReadTimeout.Duration == 0 {
fp.LoadLocalReadTimeout.Duration = defaultLoadLocalReadTimeout
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/frontend/internal_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ func (ip *internalProtocol) GetU32(id PropertyID) uint32 {
switch id {
case CONNID:
return ip.ConnectionID()
case CAPABILITY:
return 0
}
return math.MaxUint32
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/frontend/mysql_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ type Conn struct {
service string
}

// SetTimeout updates the read timeout used by ReadFromConn.
func (c *Conn) SetTimeout(d time.Duration) {
c.timeout = d
}

// NewIOSession create a new io session
func NewIOSession(conn net.Conn, pu *config.ParameterUnit, service string) (_ *Conn, err error) {
c := &Conn{
Expand Down
3 changes: 2 additions & 1 deletion pkg/frontend/mysql_cmd_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,7 @@ func doShowVariables(ses *Session, execCtx *ExecCtx, sv *tree.ShowVariables) err
}

var err error
useGlobal := sv.Global

col1 := new(MysqlColumn)
col1.SetColumnType(defines.MYSQL_TYPE_VARCHAR)
Expand Down Expand Up @@ -977,7 +978,7 @@ func doShowVariables(ses *Session, execCtx *ExecCtx, sv *tree.ShowVariables) err
}

var value interface{}
if sv.Global {
if useGlobal {
if value, err = ses.GetGlobalSysVar(name); err != nil {
continue
}
Expand Down
52 changes: 52 additions & 0 deletions pkg/frontend/mysql_cmd_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,7 @@ func Test_mysqlerror(t *testing.T) {
func Test_handleShowVariables(t *testing.T) {
ctx := defines.AttachAccountId(context.TODO(), 0)
convey.Convey("handleShowVariables succ", t, func() {
setSessionAlloc("", NewLeakCheckAllocator())
ctrl := gomock.NewController(t)
defer ctrl.Finish()

Expand Down Expand Up @@ -766,6 +767,26 @@ func Test_handleShowVariables(t *testing.T) {

shv := &tree.ShowVariables{Global: false}
convey.So(handleShowVariables(ses, ec, shv), convey.ShouldBeNil)

// Ensure global shows global value even if session differs.
ses.sesSysVars.Set("interactive_timeout", int64(30100))
ses.gSysVars.Set("interactive_timeout", int64(86400))
stmt, err := parsers.ParseOne(ctx, dialect.MYSQL, "show global variables like 'interactive_timeout'", 1)
convey.So(err, convey.ShouldBeNil)
showVars, ok := stmt.(*tree.ShowVariables)
convey.So(ok, convey.ShouldBeTrue)
ses.SetMysqlResultSet(&MysqlResultSet{})
convey.So(handleShowVariables(ses, ec, showVars), convey.ShouldBeNil)
mrs := ses.GetMysqlResultSet()
found := false
for _, row := range mrs.Data {
if len(row) >= 2 && row[0] == "interactive_timeout" {
convey.So(row[1], convey.ShouldEqual, int64(86400))
found = true
break
}
}
convey.So(found, convey.ShouldBeTrue)
})
}

Expand Down Expand Up @@ -804,6 +825,37 @@ func Test_GetComputationWrapper(t *testing.T) {
})
}

func Test_GetComputationWrapper_ShowVariablesGlobal(t *testing.T) {
convey.Convey("GetComputationWrapper show global variables", t, func() {
sql := "show global variables like 'interactive_timeout'"
var eng engine.Engine
proc := testutil.NewProcessWithMPool(t, "", mpool.MustNewZero())

sysVars := make(map[string]interface{})
for name, sysVar := range gSysVarsDefs {
sysVars[name] = sysVar.Default
}
ses := &Session{planCache: newPlanCache(1),
feSessionImpl: feSessionImpl{
gSysVars: &SystemVariables{mp: sysVars},
},
}

ctrl := gomock.NewController(t)
ec := newTestExecCtx(context.Background(), ctrl)
ec.ses = ses
ec.input = &UserInput{sql: sql}

cws, err := GetComputationWrapper(ec, "", "", eng, proc, ses)
convey.So(err, convey.ShouldBeNil)
convey.So(len(cws) > 0, convey.ShouldBeTrue)
stmt := cws[0].GetAst()
sv, ok := stmt.(*tree.ShowVariables)
convey.So(ok, convey.ShouldBeTrue)
convey.So(sv.Global, convey.ShouldBeTrue)
})
}

func runTestHandle(funName string, t *testing.T, handleFun func(ses *Session) error) {
ctx := context.TODO()
convey.Convey(fmt.Sprintf("%s succ", funName), t, func() {
Expand Down
9 changes: 7 additions & 2 deletions pkg/frontend/mysql_protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ func (mp *MysqlProtocolImpl) GetU32(id PropertyID) uint32 {
switch id {
case CONNID:
return mp.ConnectionID()
case CAPABILITY:
return mp.GetCapability()
}
return math.MaxUint32
}
Expand Down Expand Up @@ -1508,9 +1510,12 @@ func (mp *MysqlProtocolImpl) HandleHandshake(ctx context.Context, payload []byte
return false, moerr.NewInternalError(ctx, "received a broken response packet")
}

if capabilities, _, ok := mp.io.ReadUint16(payload, 0); !ok {
capabilities, _, ok := mp.io.ReadUint16(payload, 0)
if !ok {
return false, moerr.NewInternalError(ctx, "read capabilities from response packet failed")
} else if uint32(capabilities)&CLIENT_PROTOCOL_41 != 0 {
}

if uint32(capabilities)&CLIENT_PROTOCOL_41 != 0 {
var resp41 response41
var ok2 bool
mp.GetSession().Debug(ctx, "analyse handshake response")
Expand Down
86 changes: 86 additions & 0 deletions pkg/frontend/repro_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright 2022 Matrix Origin
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package frontend

import (
"context"
"testing"

"github.com/matrixorigin/matrixone/pkg/config"
"github.com/stretchr/testify/require"
)

func TestInteractiveTimeoutScopeLeak(t *testing.T) {
ctx := context.Background()

// Setup session
sv := &config.FrontendParameters{}
sv.SetDefaultValues()
pu := config.NewParameterUnit(sv, nil, nil, nil)
setPu("", pu)
setSessionAlloc("", NewLeakCheckAllocator())

// Build a minimal Session stub for sysvar scope checks.
// This test does not need network protocol behaviors.
ses := &Session{
feSessionImpl: feSessionImpl{
service: "",
},
}

// Initialize system variables
// We need to mock GSysVarsMgr behavior or just manually init for test
// Since GSysVarsMgr depends on table access, let's manually setup gSysVars and sesSysVars
// to mimic InitSystemVariables behavior

gVars := &SystemVariables{mp: make(map[string]interface{})}
// Default interactive_timeout
gVars.Set("interactive_timeout", int64(86400))

ses.gSysVars = gVars
ses.sesSysVars = ses.gSysVars.Clone()

// Verify initial state
val, err := ses.GetGlobalSysVar("interactive_timeout")
require.NoError(t, err)
require.Equal(t, int64(86400), val)

val, err = ses.GetSessionSysVar("interactive_timeout")
require.NoError(t, err)
require.Equal(t, int64(86400), val)

// 1. Set Session Variable
err = ses.SetSessionSysVar(ctx, "interactive_timeout", int64(30100))
require.NoError(t, err)

// Verify Session updated
val, err = ses.GetSessionSysVar("interactive_timeout")
require.NoError(t, err)
require.Equal(t, int64(30100), val)

// Verify Global UNCHANGED
val, err = ses.GetGlobalSysVar("interactive_timeout")
require.NoError(t, err)
require.Equal(t, int64(86400), val, "Global variable should not change after session set")

// 2. Set Global Variable (should fail)
err = ses.SetGlobalSysVar(ctx, "interactive_timeout", int64(30100))
require.Error(t, err) // Should fail with read-only

// Verify Global UNCHANGED
val, err = ses.GetGlobalSysVar("interactive_timeout")
require.NoError(t, err)
require.Equal(t, int64(86400), val, "Global variable should not change after failed global set")
}
81 changes: 77 additions & 4 deletions pkg/frontend/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,23 +428,54 @@ func (mo *MOServer) handshake(rs *Conn) error {
if err := protocol.Authenticate(tempCtx); err != nil {
return moerr.AttachCause(tempCtx, err)
}
mo.applyInteractiveWaitTimeout(tempCtx, ses, protocol)
protocol.SetBool(TLS_ESTABLISHED, true)
protocol.SetBool(ESTABLISHED, true)
}
} else {
ses.Debugf(tempCtx, "handleHandshake")
_, err = protocol.HandleHandshake(tempCtx, payload)
isTlsHeader, err = protocol.HandleHandshake(tempCtx, payload)
if err != nil {
err = moerr.AttachCause(tempCtx, err)
ses.Error(tempCtx,
"Error occurred",
zap.Error(err))
return err
}
if err = protocol.Authenticate(tempCtx); err != nil {
return moerr.AttachCause(tempCtx, err)
if isTlsHeader {
ts[TSUpgradeTLSStart] = time.Now()
ses.Debugf(tempCtx, "upgrade to TLS")
// do upgradeTls
tlsConn := tls.Server(rs.RawConn(), rm.getTlsConfig())
ses.Debugf(tempCtx, "get TLS conn ok")
tlsCtx, cancelFun := context.WithTimeoutCause(tempCtx, 20*time.Second, moerr.CauseHandshake2)
if err = tlsConn.HandshakeContext(tlsCtx); err != nil {
err = moerr.AttachCause(tlsCtx, err)
ses.Error(tempCtx,
"Error occurred before cancel()",
zap.Error(err))
cancelFun()
ses.Error(tempCtx,
"Error occurred after cancel()",
zap.Error(err))
return err
}
cancelFun()
ses.Debugf(tempCtx, "TLS handshake ok")
rs.UseConn(tlsConn)
ses.Debugf(tempCtx, "TLS handshake finished")

// tls upgradeOk
protocol.SetBool(TLS_ESTABLISHED, true)
ts[TSUpgradeTLSEnd] = time.Now()
v2.UpgradeTLSDurationHistogram.Observe(ts[TSUpgradeTLSEnd].Sub(ts[TSUpgradeTLSStart]).Seconds())
} else {
if err = protocol.Authenticate(tempCtx); err != nil {
return moerr.AttachCause(tempCtx, err)
}
mo.applyInteractiveWaitTimeout(tempCtx, ses, protocol)
protocol.SetBool(ESTABLISHED, true)
}
protocol.SetBool(ESTABLISHED, true)
}
ts[TSEstablishEnd] = time.Now()
v2.EstablishDurationHistogram.Observe(ts[TSEstablishEnd].Sub(ts[TSEstablishStart]).Seconds())
Expand Down Expand Up @@ -636,6 +667,7 @@ func (mo *MOServer) handleRequest(rs *Conn) error {
return io.EOF
}

mo.applyIdleTimeout(rs)
msg, err = rs.Read()
if err != nil {
if err == io.EOF {
Expand All @@ -657,3 +689,44 @@ func (mo *MOServer) handleRequest(rs *Conn) error {
}
return nil
}

func (mo *MOServer) applyIdleTimeout(rs *Conn) {
rm := mo.rm
if rm == nil {
return
}
routine := rm.getRoutine(rs)
if routine == nil {
return
}
ses := routine.getSession()
if ses == nil {
return
}
val, err := ses.GetSessionSysVar("wait_timeout")
if err != nil {
return
}
timeoutSec, ok := val.(int64)
if !ok {
return
}
if timeoutSec <= 0 {
rs.SetTimeout(0)
return
}
rs.SetTimeout(time.Duration(timeoutSec) * time.Second)
}

func (mo *MOServer) applyInteractiveWaitTimeout(ctx context.Context, ses *Session, protocol MysqlRrWr) {
if protocol.GetU32(CAPABILITY)&CLIENT_INTERACTIVE == 0 {
return
}
val, err := ses.GetSessionSysVar("interactive_timeout")
if err != nil {
return
}
if err = ses.SetSessionSysVar(ctx, "wait_timeout", val); err != nil {
ses.Errorf(ctx, "set wait_timeout from interactive_timeout failed: %v", err)
}
}
Loading
Loading