diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 7cae32f..3f303ab 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -14,11 +14,24 @@ import ( ) const ( - defaultReadTimeout = 10 * time.Second + defaultReadTimeout = 10 * time.Second + maxRepeatedQuestion = 2 ) var ErrNotFoundAnswer = errors.New("not found answer") +type QuestionExceptionRepeated struct { + RepeatedQuestionCount int +} + +func (e *QuestionExceptionRepeated) Error() string { + return fmt.Sprintf("repeated question count=%d", e.RepeatedQuestionCount) +} + +func ThrowQuestionExceptionRepeated(repeatedQuestionCount int) error { + return &QuestionExceptionRepeated{RepeatedQuestionCount: repeatedQuestionCount} +} + type Res struct { output []byte error []byte @@ -84,7 +97,7 @@ type Cmd interface { // to call if matched. GetExprCallback() ([]string, map[string]string) // QuestionHandler is called when question is matched. - QuestionHandler(question []byte) ([]byte, error) + QuestionHandler(question []byte, attempt int) ([]byte, error) // GetQuestionExprs returns list of possible questions. GetQuestionExprs() []expr.Expr // ErrorHandler is called when where is an error found in output. @@ -142,7 +155,7 @@ func (m CmdImpl) GetExprCallback() ([]string, map[string]string) { return res, exprToCB } -func (m CmdImpl) QuestionHandler(question []byte) ([]byte, error) { +func (m CmdImpl) QuestionHandler(question []byte, attempt int) ([]byte, error) { for _, cmdAnswer := range m.questionAnswers { ans, ok, err := cmdAnswer.Match(question) if err != nil { @@ -151,6 +164,9 @@ func (m CmdImpl) QuestionHandler(question []byte) ([]byte, error) { if !ok { continue } + if attempt > cmdAnswer.maxAttempts { + return nil, ThrowQuestionExceptionRepeated(attempt) + } if !cmdAnswer.notSendNL { ans = append(ans, []byte("\n")...) } @@ -226,9 +242,10 @@ func WithForwarding(forward bool) CmdOption { } type Answer struct { - question string - answer string - notSendNL bool + question string + answer string + notSendNL bool + maxAttempts int } func (m Answer) Match(question []byte) ([]byte, bool, error) { @@ -266,11 +283,15 @@ func (m Answer) GetExpr() expr.Expr { } func NewAnswer(question, answer string, notSendNL bool) Answer { - return Answer{question: question, answer: answer, notSendNL: notSendNL} + return Answer{question: question, answer: answer, notSendNL: notSendNL, maxAttempts: maxRepeatedQuestion} } func NewAnswerWithNL(question, answer string) Answer { - return Answer{question: question, answer: answer, notSendNL: false} + return Answer{question: question, answer: answer, notSendNL: false, maxAttempts: maxRepeatedQuestion} +} + +func NewAnswerWithNLMaxAttempts(question, answer string, maxAttempts int) Answer { + return Answer{question: question, answer: answer, notSendNL: false, maxAttempts: maxAttempts} } func WithExprCallback(exprCallbacks ...ExprCallback) CmdOption { diff --git a/pkg/device/errors.go b/pkg/device/errors.go index 29946ce..d506d48 100644 --- a/pkg/device/errors.go +++ b/pkg/device/errors.go @@ -57,16 +57,3 @@ func (e *QuestionException) Error() string { func ThrowQuestionException(question []byte) error { return &QuestionException{Question: question} } - -type QuestionExceptionRepeated struct { - Question []byte - RepeatedQuestionCount int -} - -func (e *QuestionExceptionRepeated) Error() string { - return fmt.Sprintf("repeated question: %s, repeated question count: %d", e.Question, e.RepeatedQuestionCount) -} - -func ThrowQuestionExceptionRepeated(question []byte, repeatedQuestionCount int) error { - return &QuestionExceptionRepeated{Question: question, RepeatedQuestionCount: repeatedQuestionCount} -} diff --git a/pkg/device/genericcli/genericcli.go b/pkg/device/genericcli/genericcli.go index 4ec016b..cbd6c64 100644 --- a/pkg/device/genericcli/genericcli.go +++ b/pkg/device/genericcli/genericcli.go @@ -26,7 +26,6 @@ var ErrorCLILogin = errors.New("CLI login is not supported") const AnyNLPattern = `(\r\n|\n)` const DefaultCLIConnectTimeout = 15 * time.Second -const maxRepeatedQuestion = 2 const ( promptExprName = "prompt" @@ -725,11 +724,8 @@ func GenericExecute(command cmd.Cmd, connector streamer.Connector, cli GenericCL repeatedQuestionCount = 1 lastQuestion = question } - if repeatedQuestionCount > maxRepeatedQuestion { - return nil, device.ThrowQuestionExceptionRepeated(question, repeatedQuestionCount) - } logger.Debug("QuestionHandler question", zap.ByteString("question", question)) - answer, err := command.QuestionHandler(question) + answer, err := command.QuestionHandler(question, repeatedQuestionCount) if err != nil { if errors.Is(err, cmd.ErrNotFoundAnswer) { return nil, device.ThrowQuestionException(question) diff --git a/pkg/device/genericcli/gop3_repeated_question_test.go b/pkg/device/genericcli/gop3_repeated_question_test.go index 205f8ed..e40ad87 100644 --- a/pkg/device/genericcli/gop3_repeated_question_test.go +++ b/pkg/device/genericcli/gop3_repeated_question_test.go @@ -48,10 +48,52 @@ func TestRepeatedQuestionAborts(t *testing.T) { require.NoError(t, err) require.NoError(t, serverErr) require.Error(t, resErr) - var qErr *device.QuestionExceptionRepeated + var qErr *cmd.QuestionExceptionRepeated require.ErrorAs(t, resErr, &qErr) } +func TestRepeatedQuestionCustomAttempts(t *testing.T) { + logger := zap.Must(zap.NewDevelopmentConfig().Build()) + + dialog := [][]gmock.Action{ + { + gmock.Send(""), + gmock.Expect("enable\n"), + gmock.SendEcho("enable\r\n"), + gmock.Send("Password:"), + gmock.Expect("mypass\n"), + gmock.Send("Error: Incorrect password.\r\nPassword:"), + gmock.Expect("mypass\n"), + gmock.Send("Error: Incorrect password.\r\nPassword:"), + gmock.Expect("mypass\n"), + gmock.Send("Error: Incorrect password.\r\nPassword:"), + gmock.Expect("mypass\n"), + gmock.Send("Error: Incorrect password.\r\n"), + gmock.Close(), + }, + } + + actions := gmock.ConcatMultipleSlices(dialog) + cmds := []cmd.Cmd{ + cmd.NewCmd("enable", cmd.WithAddAnswers( + cmd.NewAnswerWithNLMaxAttempts("Password:", "mypass", 100), + )), + } + + cm, resErr, serverErr, err := gmock.RunCmd(func(connector streamer.Connector) device.Device { + dev := newDevice(fullQuestion, connector, logger) + return &dev + }, actions, cmds, logger) + + require.NoError(t, err) + require.NoError(t, serverErr) + require.NoError(t, resErr) + require.Len(t, cm, 1) + require.Empty(t, cm[0].Output()) + require.Equal(t, "Error: Incorrect password.", string(cm[0].Error())) + require.Equal(t, 1, cm[0].Status()) +} + func TestDifferentQuestionsDoNotTriggerLimit(t *testing.T) { logger := zap.Must(zap.NewDevelopmentConfig().Build())