diff --git a/auto.go b/auto.go index f3bece3..8314121 100644 --- a/auto.go +++ b/auto.go @@ -117,6 +117,19 @@ func hashValueReply(v HashValue) (*MultiBulkReply, error) { return MultiBulkFromMap(m), nil } +func indexValueReply(v map[int][]byte) (*MultiBulkReply, error) { + fmt.Println(v) + fmt.Println(len(v)) + m := make([]interface{}, len(v)*2) + i := 0 + for k, v := range v { + m[i] = v + m[i+1] = k + i += 2 + } + return &MultiBulkReply{values: m}, nil +} + func (srv *Server) createReply(r *Request, val interface{}) (ReplyWriter, error) { Debugf("CREATE REPLY: %T", val) switch v := val.(type) { @@ -139,6 +152,8 @@ func (srv *Server) createReply(r *Request, val interface{}) (ReplyWriter, error) return hashValueReply(v) case map[string][]byte: return hashValueReply(v) + case map[int][]byte: + return indexValueReply(v) case map[string]interface{}: return MultiBulkFromMap(v), nil case int: @@ -156,7 +171,7 @@ func (srv *Server) createReply(r *Request, val interface{}) (ReplyWriter, error) case *MultiChannelWriter: println("New client") for _, mcw := range v.Chans { - mcw.clientChan = r.ClientChan + mcw.ClientChan = r.ClientChan } return v, nil default: diff --git a/parser.go b/parser.go index caaf49c..89ad99a 100644 --- a/parser.go +++ b/parser.go @@ -4,7 +4,7 @@ import ( "bufio" "fmt" "io" - "io/ioutil" + "strconv" "strings" ) @@ -13,6 +13,7 @@ func parseRequest(conn io.ReadCloser) (*Request, error) { // first line of redis request should be: // *CRLF line, err := r.ReadString('\n') + // fmt.Println(line) if err != nil { return nil, err } @@ -21,7 +22,8 @@ func parseRequest(conn io.ReadCloser) (*Request, error) { // Multiline request: if line[0] == '*' { - if _, err := fmt.Sscanf(line, "*%d\r", &argsCount); err != nil { + argsCount, err = strconv.Atoi(strings.Trim(line, "* \r\n")) + if err != nil { return nil, malformed("*", line) } // All next lines are pairs of: @@ -32,14 +34,12 @@ func parseRequest(conn io.ReadCloser) (*Request, error) { if err != nil { return nil, err } - args := make([][]byte, argsCount-1) for i := 0; i < argsCount-1; i += 1 { if args[i], err = readArgument(r); err != nil { return nil, err } } - return &Request{ Name: strings.ToLower(string(firstArg)), Args: args, @@ -56,6 +56,7 @@ func parseRequest(conn io.ReadCloser) (*Request, error) { args = append(args, []byte(arg)) } } + fmt.Println(strings.ToLower(string(fields[0]))) return &Request{ Name: strings.ToLower(string(fields[0])), Args: args, @@ -71,19 +72,17 @@ func readArgument(r *bufio.Reader) ([]byte, error) { return nil, malformed("$", line) } var argSize int - if _, err := fmt.Sscanf(line, "$%d\r", &argSize); err != nil { + argSize, err = strconv.Atoi(strings.Trim(line, "$ \r\n")) + if err != nil { return nil, malformed("$", line) } // I think int is safe here as the max length of request // should be less then max int value? - data, err := ioutil.ReadAll(io.LimitReader(r, int64(argSize))) + data := make([]byte, argSize) + n, err := io.ReadFull(r, data) if err != nil { - return nil, err - } - - if len(data) != argSize { - return nil, malformedLength(argSize, len(data)) + return nil, malformedLength(argSize, n) } // Now check for trailing CR diff --git a/reply.go b/reply.go index fa9485a..23ae7bf 100644 --- a/reply.go +++ b/reply.go @@ -180,7 +180,7 @@ func (c *MultiChannelWriter) WriteTo(w io.Writer) (n int64, err error) { type ChannelWriter struct { FirstReply []interface{} Channel chan []interface{} - clientChan chan struct{} + ClientChan chan struct{} } func (c *ChannelWriter) WriteTo(w io.Writer) (int64, error) { @@ -191,7 +191,7 @@ func (c *ChannelWriter) WriteTo(w io.Writer) (int64, error) { for { select { - case <-c.clientChan: + case <-c.ClientChan: return totalBytes, err case reply := <-c.Channel: if reply == nil { diff --git a/server.go b/server.go index c4a4131..e272b6e 100644 --- a/server.go +++ b/server.go @@ -65,7 +65,7 @@ func (srv *Server) ServeClient(conn net.Conn) (err error) { clientChan := make(chan struct{}) // Read on `conn` in order to detect client disconnect - go func() { + defer func() { // Close chan in order to trigger eventual selects defer close(clientChan) defer Debugf("Client disconnected") @@ -106,6 +106,55 @@ func (srv *Server) ServeClient(conn net.Conn) (err error) { return nil } +func (srv *Server) ServeReplClient(conn net.Conn) (err error) { + defer func() { + if err != nil { + fmt.Fprintf(conn, "-%s\n", err) + } + conn.Close() + }() + + clientChan := make(chan struct{}) + + // Read on `conn` in order to detect client disconnect + defer func() { + // Close chan in order to trigger eventual selects + defer close(clientChan) + defer Debugf("Client disconnected") + // FIXME: move conn within the request. + if false { + io.Copy(ioutil.Discard, conn) + } + }() + + var clientAddr string + + switch co := conn.(type) { + case *net.UnixConn: + f, err := conn.(*net.UnixConn).File() + if err != nil { + return err + } + clientAddr = f.Name() + default: + clientAddr = co.RemoteAddr().String() + } + + for { + request, err := parseRequest(conn) + if err != nil { + return err + } + request.Host = clientAddr + request.ClientChan = clientChan + _, err = srv.Apply(request) + if err != nil { + return err + } + } + return nil +} + func NewServer(c *Config) (*Server, error) { srv := &Server{ Proto: c.proto,