diff --git a/channel.go b/channel.go index e1da7b5..abf2f2b 100644 --- a/channel.go +++ b/channel.go @@ -9,21 +9,15 @@ import ( // PoolConfig 连接池相关配置 type PoolConfig struct { - //连接池中拥有的最小连接数 - InitialCap int - //连接池中拥有的最大的连接数 - MaxCap int - //生成连接的方法 - Factory func() (interface{}, error) - //关闭连接的方法 - Close func(interface{}) error - //检查连接是否有效的方法 - Ping func(interface{}) error - //连接最大空闲时间,超过该事件则将失效 - IdleTimeout time.Duration + InitialCap int // 连接池中拥有的最小连接数 + MaxCap int // 连接池中拥有的最大的连接数 + Factory func() (interface{}, error) // 生成连接的方法 + Close func(interface{}) error // 关闭连接的方法 + Ping func(interface{}) error // 检查连接是否有效的方法 + IdleTimeout time.Duration // 连接最大空闲时间,超过该事件则将失效 } -//channelPool 存放连接信息 +// channelPool 存放连接信息 type channelPool struct { mu sync.Mutex conns chan *idleConn @@ -38,7 +32,7 @@ type idleConn struct { t time.Time } -//NewChannelPool 初始化连接 +// NewChannelPool 初始化连接 func NewChannelPool(poolConfig *PoolConfig) (Pool, error) { if poolConfig.InitialCap < 0 || poolConfig.MaxCap <= 0 || poolConfig.InitialCap > poolConfig.MaxCap { return nil, errors.New("invalid capacity settings") @@ -73,7 +67,7 @@ func NewChannelPool(poolConfig *PoolConfig) (Pool, error) { return c, nil } -//getConns 获取所有连接 +// getConns 获取所有连接 func (c *channelPool) getConns() chan *idleConn { c.mu.Lock() conns := c.conns @@ -81,7 +75,7 @@ func (c *channelPool) getConns() chan *idleConn { return conns } -//Get 从pool中取一个连接 +// Get 从pool中取一个连接 func (c *channelPool) Get() (interface{}, error) { conns := c.getConns() if conns == nil { @@ -93,7 +87,7 @@ func (c *channelPool) Get() (interface{}, error) { if wrapConn == nil { return nil, ErrClosed } - //判断是否超时,超时则丢弃 + // 判断是否超时,超时则丢弃 if timeout := c.idleTimeout; timeout > 0 { if wrapConn.t.Add(timeout).Before(time.Now()) { //丢弃并关闭该连接 @@ -101,7 +95,7 @@ func (c *channelPool) Get() (interface{}, error) { continue } } - //判断是否失效,失效则丢弃,如果用户没有设定 ping 方法,就不检查 + // 判断是否失效,失效则丢弃,如果用户没有设定 ping 方法,就不检查 if c.ping != nil { if err := c.Ping(wrapConn.conn); err != nil { fmt.Println("conn is not able to be connected: ", err) @@ -110,17 +104,24 @@ func (c *channelPool) Get() (interface{}, error) { } return wrapConn.conn, nil default: - conn, err := c.factory() - if err != nil { - return nil, err - } - - return conn, nil + return c.Connect() } } } -//Put 将连接放回pool中 +// Conn 重新创建一个连接 +func (c *channelPool) Connect() (interface{}, error) { + if c.factory == nil { + return nil, errors.New("factory func is nil. rejecting") + } + conn, err := c.factory() + if err != nil { + return nil, err + } + return conn, nil +} + +// Put 将连接放回pool中 func (c *channelPool) Put(conn interface{}) error { if conn == nil { return errors.New("connection is nil. rejecting") @@ -142,28 +143,35 @@ func (c *channelPool) Put(conn interface{}) error { } } -//Close 关闭单条连接 +// Close 关闭单条连接 func (c *channelPool) Close(conn interface{}) error { if conn == nil { return errors.New("connection is nil. rejecting") } + if c.close == nil { + return errors.New("close func is nil. rejecting") + } return c.close(conn) } -//Ping 检查单条连接是否有效 +// Ping 检查单条连接是否有效 func (c *channelPool) Ping(conn interface{}) error { if conn == nil { return errors.New("connection is nil. rejecting") } + if c.ping == nil { + return errors.New("ping func is nil. rejecting") + } return c.ping(conn) } -//Release 释放连接池中所有连接 +// Release 释放连接池中所有连接 func (c *channelPool) Release() { c.mu.Lock() conns := c.conns c.conns = nil c.factory = nil + c.ping = nil closeFun := c.close c.close = nil c.mu.Unlock() @@ -178,7 +186,7 @@ func (c *channelPool) Release() { } } -//Len 连接池中已有的连接 +// Len 连接池中已有的连接 func (c *channelPool) Len() int { return len(c.getConns()) } diff --git a/pool.go b/pool.go index 7ec40ff..23680ca 100644 --- a/pool.go +++ b/pool.go @@ -3,14 +3,16 @@ package pool import "errors" var ( - //ErrClosed 连接池已经关闭Error + // ErrClosed 连接池已经关闭Error ErrClosed = errors.New("pool is closed") ) -//Pool 基本方法 +// Pool 基本方法 type Pool interface { Get() (interface{}, error) + Connect() (interface{}, error) + Put(interface{}) error Close(interface{}) error