Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions pkg/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,24 @@ import (
)

const (
defaultReadTimeout = 10 * time.Second
defaultReadTimeout = 10 * time.Second
maxRepeatedQuestion = 2
)

var ErrNotFoundAnswer = errors.New("not found answer")

type QuestionExceptionRepeated struct {
RepeatedQuestionCount int
}

func (e *QuestionExceptionRepeated) Error() string {
return fmt.Sprintf("repeated question count=%d", e.RepeatedQuestionCount)
}

func ThrowQuestionExceptionRepeated(repeatedQuestionCount int) error {
return &QuestionExceptionRepeated{RepeatedQuestionCount: repeatedQuestionCount}
}

type Res struct {
output []byte
error []byte
Expand Down Expand Up @@ -84,7 +97,7 @@ type Cmd interface {
// to call if matched.
GetExprCallback() ([]string, map[string]string)
// QuestionHandler is called when question is matched.
QuestionHandler(question []byte) ([]byte, error)
QuestionHandler(question []byte, attempt int) ([]byte, error)
// GetQuestionExprs returns list of possible questions.
GetQuestionExprs() []expr.Expr
// ErrorHandler is called when where is an error found in output.
Expand Down Expand Up @@ -142,7 +155,7 @@ func (m CmdImpl) GetExprCallback() ([]string, map[string]string) {
return res, exprToCB
}

func (m CmdImpl) QuestionHandler(question []byte) ([]byte, error) {
func (m CmdImpl) QuestionHandler(question []byte, attempt int) ([]byte, error) {
for _, cmdAnswer := range m.questionAnswers {
ans, ok, err := cmdAnswer.Match(question)
if err != nil {
Expand All @@ -151,6 +164,9 @@ func (m CmdImpl) QuestionHandler(question []byte) ([]byte, error) {
if !ok {
continue
}
if attempt > cmdAnswer.maxAttempts {
return nil, ThrowQuestionExceptionRepeated(attempt)
}
if !cmdAnswer.notSendNL {
ans = append(ans, []byte("\n")...)
}
Expand Down Expand Up @@ -226,9 +242,10 @@ func WithForwarding(forward bool) CmdOption {
}

type Answer struct {
question string
answer string
notSendNL bool
question string
answer string
notSendNL bool
maxAttempts int
}

func (m Answer) Match(question []byte) ([]byte, bool, error) {
Expand Down Expand Up @@ -266,11 +283,15 @@ func (m Answer) GetExpr() expr.Expr {
}

func NewAnswer(question, answer string, notSendNL bool) Answer {
return Answer{question: question, answer: answer, notSendNL: notSendNL}
return Answer{question: question, answer: answer, notSendNL: notSendNL, maxAttempts: maxRepeatedQuestion}
}

func NewAnswerWithNL(question, answer string) Answer {
return Answer{question: question, answer: answer, notSendNL: false}
return Answer{question: question, answer: answer, notSendNL: false, maxAttempts: maxRepeatedQuestion}
}

func NewAnswerWithNLMaxAttempts(question, answer string, maxAttempts int) Answer {
return Answer{question: question, answer: answer, notSendNL: false, maxAttempts: maxAttempts}
}

func WithExprCallback(exprCallbacks ...ExprCallback) CmdOption {
Expand Down
13 changes: 0 additions & 13 deletions pkg/device/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,3 @@ func (e *QuestionException) Error() string {
func ThrowQuestionException(question []byte) error {
return &QuestionException{Question: question}
}

type QuestionExceptionRepeated struct {
Question []byte
RepeatedQuestionCount int
}

func (e *QuestionExceptionRepeated) Error() string {
return fmt.Sprintf("repeated question: %s, repeated question count: %d", e.Question, e.RepeatedQuestionCount)
}

func ThrowQuestionExceptionRepeated(question []byte, repeatedQuestionCount int) error {
return &QuestionExceptionRepeated{Question: question, RepeatedQuestionCount: repeatedQuestionCount}
}
6 changes: 1 addition & 5 deletions pkg/device/genericcli/genericcli.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ var ErrorCLILogin = errors.New("CLI login is not supported")

const AnyNLPattern = `(\r\n|\n)`
const DefaultCLIConnectTimeout = 15 * time.Second
const maxRepeatedQuestion = 2

const (
promptExprName = "prompt"
Expand Down Expand Up @@ -725,11 +724,8 @@ func GenericExecute(command cmd.Cmd, connector streamer.Connector, cli GenericCL
repeatedQuestionCount = 1
lastQuestion = question
}
if repeatedQuestionCount > maxRepeatedQuestion {
return nil, device.ThrowQuestionExceptionRepeated(question, repeatedQuestionCount)
}
logger.Debug("QuestionHandler question", zap.ByteString("question", question))
answer, err := command.QuestionHandler(question)
answer, err := command.QuestionHandler(question, repeatedQuestionCount)
if err != nil {
if errors.Is(err, cmd.ErrNotFoundAnswer) {
return nil, device.ThrowQuestionException(question)
Expand Down
44 changes: 43 additions & 1 deletion pkg/device/genericcli/gop3_repeated_question_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,52 @@ func TestRepeatedQuestionAborts(t *testing.T) {
require.NoError(t, err)
require.NoError(t, serverErr)
require.Error(t, resErr)
var qErr *device.QuestionExceptionRepeated
var qErr *cmd.QuestionExceptionRepeated
require.ErrorAs(t, resErr, &qErr)
}

func TestRepeatedQuestionCustomAttempts(t *testing.T) {
logger := zap.Must(zap.NewDevelopmentConfig().Build())

dialog := [][]gmock.Action{
{
gmock.Send("<device>"),
gmock.Expect("enable\n"),
gmock.SendEcho("enable\r\n"),
gmock.Send("Password:"),
gmock.Expect("mypass\n"),
gmock.Send("Error: Incorrect password.\r\nPassword:"),
gmock.Expect("mypass\n"),
gmock.Send("Error: Incorrect password.\r\nPassword:"),
gmock.Expect("mypass\n"),
gmock.Send("Error: Incorrect password.\r\nPassword:"),
gmock.Expect("mypass\n"),
gmock.Send("Error: Incorrect password.\r\n<device>"),
gmock.Close(),
},
}

actions := gmock.ConcatMultipleSlices(dialog)
cmds := []cmd.Cmd{
cmd.NewCmd("enable", cmd.WithAddAnswers(
cmd.NewAnswerWithNLMaxAttempts("Password:", "mypass", 100),
)),
}

cm, resErr, serverErr, err := gmock.RunCmd(func(connector streamer.Connector) device.Device {
dev := newDevice(fullQuestion, connector, logger)
return &dev
}, actions, cmds, logger)

require.NoError(t, err)
require.NoError(t, serverErr)
require.NoError(t, resErr)
require.Len(t, cm, 1)
require.Empty(t, cm[0].Output())
require.Equal(t, "Error: Incorrect password.", string(cm[0].Error()))
require.Equal(t, 1, cm[0].Status())
}

func TestDifferentQuestionsDoNotTriggerLimit(t *testing.T) {
logger := zap.Must(zap.NewDevelopmentConfig().Build())

Expand Down
Loading