diff --git a/cluster.go b/cluster.go index cc97604..9028ff1 100644 --- a/cluster.go +++ b/cluster.go @@ -62,6 +62,7 @@ func newProximityCache() *proximityCache { // Cluster holds the information about the state of the network. It is the main interface to the distributed network of Nodes. type Cluster struct { self *Node + transport Transport table *routingTable leafset *leafSet neighborhoodset *neighborhoodSet @@ -200,9 +201,10 @@ func (c *Cluster) SetNetworkTimeout(timeout int) { } // NewCluster creates a new instance of a connection to the network and intialises the state tables and channels it requires. -func NewCluster(self *Node, credentials Credentials) *Cluster { +func NewCluster(self *Node, credentials Credentials, transport Transport) *Cluster { return &Cluster{ self: self, + transport: transport, table: newRoutingTable(self), leafset: newLeafSet(self), neighborhoodset: newNeighborhoodSet(self), @@ -259,7 +261,7 @@ func (c *Cluster) RegisterCallback(app Application) { func (c *Cluster) Listen() error { portstr := strconv.Itoa(c.self.Port) c.debug("Listening on port %d", c.self.Port) - ln, err := net.Listen("tcp", ":"+portstr) + ln, err := c.transport.Listen(":" + portstr) if err != nil { return err } @@ -514,7 +516,7 @@ func (c *Cluster) send(msg Message, destination *Node) error { // SendToIP sends a message directly to an IP using the Wendy networking logic. func (c *Cluster) SendToIP(msg Message, address string) error { c.debug("Sending message %s", string(msg.Value)) - conn, err := net.DialTimeout("tcp", address, time.Duration(c.getNetworkTimeout())*time.Second) + conn, err := c.transport.DialTimeout(address, time.Duration(c.getNetworkTimeout())*time.Second) if err != nil { c.debug(err.Error()) return deadNodeError diff --git a/integration_test.go b/integration_test.go index c976078..0caf407 100644 --- a/integration_test.go +++ b/integration_test.go @@ -85,7 +85,7 @@ func makeCluster(idBytes string) (*Cluster, error) { return nil, err } node := NewNode(id, "127.0.0.1", "127.0.0.1", "testing", 0) - cluster := NewCluster(node, nil) + cluster := NewCluster(node, nil, NewTCPTransport()) cluster.SetHeartbeatFrequency(10) cluster.SetNetworkTimeout(1) cluster.SetLogLevel(LogLevelDebug) diff --git a/transport.go b/transport.go new file mode 100644 index 0000000..95fe6ce --- /dev/null +++ b/transport.go @@ -0,0 +1,12 @@ +package wendy + +import ( + "net" + "time" +) + +// Transport represents a low-level network interface +type Transport interface { + Listen(laddr string) (net.Listener, error) + DialTimeout(addr string, timeout time.Duration) (net.Conn, error) +} diff --git a/transport_tcp.go b/transport_tcp.go new file mode 100644 index 0000000..ff28930 --- /dev/null +++ b/transport_tcp.go @@ -0,0 +1,22 @@ +package wendy + +import ( + "net" + "time" +) + +// NewTCPTransport returns a new TCP transport +func NewTCPTransport() Transport { + return &tcpTransport{} +} + +type tcpTransport struct { +} + +func (t *tcpTransport) Listen(laddr string) (net.Listener, error) { + return net.Listen("tcp", laddr) +} + +func (t *tcpTransport) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("tcp", addr, timeout) +} diff --git a/transport_tcp_test.go b/transport_tcp_test.go new file mode 100644 index 0000000..9d7d7e9 --- /dev/null +++ b/transport_tcp_test.go @@ -0,0 +1,73 @@ +package wendy + +import ( + "bytes" + "io" + "testing" + "time" +) + +func TestTCPTransport(t *testing.T) { + baton := make(chan struct{}, 1) + + go func() { + transport := NewTCPTransport() + l, err := transport.Listen("0.0.0.0:2999") + if err != nil { + t.Fatal(err) + } + defer l.Close() + + baton <- struct{}{} + + conn, err := l.Accept() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + <-baton + + var buf [1024]byte + n, err := conn.Read(buf[:]) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(buf[:n], []byte("Hello Word")) { + t.Fatalf("expected %q instead of %q", "Hello World", buf[:n]) + } + + err = conn.Close() + if err != nil { + t.Fatal(err) + } + + baton <- struct{}{} + }() + + func() { + transport := NewTCPTransport() + <-baton + + conn, err := transport.DialTimeout("127.0.0.1:2999", 10*time.Second) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + _, err = io.WriteString(conn, "Hello World") + if err != nil { + t.Fatal(err) + } + + baton <- struct{}{} + + err = conn.Close() + if err != nil { + t.Fatal(err) + } + }() + + <-baton +}