From eebad6c75221c00c2b64dd9fdfed99a8f9a67ccb Mon Sep 17 00:00:00 2001 From: domikedos Date: Tue, 17 Jun 2025 15:27:33 +0600 Subject: [PATCH 1/2] implement /state endpoint --- api/claim-api.yml | 22 +++++ pkg/api/handler.go | 20 ++++ pkg/api/oas/oas_client_gen.go | 74 +++++++++++++++ pkg/api/oas/oas_handlers_gen.go | 102 ++++++++++++++++++++ pkg/api/oas/oas_json_gen.go | 113 +++++++++++++++++++++++ pkg/api/oas/oas_response_decoders_gen.go | 92 ++++++++++++++++++ pkg/api/oas/oas_response_encoders_gen.go | 14 +++ pkg/api/oas/oas_router_gen.go | 46 +++++++++ pkg/api/oas/oas_schemas_gen.go | 26 ++++++ pkg/api/oas/oas_server_gen.go | 4 + pkg/api/oas/oas_unimplemented_gen.go | 7 ++ pkg/api/oas/oas_validators_gen.go | 23 +++++ pkg/prover/enumerate.go | 51 ++++++++++ pkg/prover/enumerate_test.go | 40 ++++++++ pkg/prover/prover.go | 38 ++++++++ 15 files changed, 672 insertions(+) create mode 100644 pkg/prover/enumerate_test.go diff --git a/api/claim-api.yml b/api/claim-api.yml index 6e45720..dcb98d0 100644 --- a/api/claim-api.yml +++ b/api/claim-api.yml @@ -63,6 +63,18 @@ paths: $ref: '#/components/schemas/WalletList' 'default': $ref: '#/components/responses/Error' + /state: + get: + operationId: getState + responses: + '200': + description: TBD + content: + application/json: + schema: + $ref: '#/components/schemas/State' + 'default': + $ref: '#/components/responses/Error' components: schemas: @@ -125,6 +137,16 @@ components: type: string expired_at: type: string + State: + type: object + required: + - total_wallets + - master_address + properties: + total_wallets: + type: number + master_address: + type: string responses: Error: diff --git a/pkg/api/handler.go b/pkg/api/handler.go index e52759c..8196064 100644 --- a/pkg/api/handler.go +++ b/pkg/api/handler.go @@ -396,3 +396,23 @@ func (h *Handler) jettonMasterState(accountID ton.AccountID) (state [2]string, o state, ok = h.jettonMasterStateCache[accountID] return } + +func (h *Handler) GetState(ctx context.Context) (*oas.State, error) { + ch := make(chan prover.AccountCountResponse, 1) + h.prover.Queue() <- prover.AccountCountRequest{ + ResponseCh: ch, + } + select { + case <-ctx.Done(): + return nil, BadRequest("timeout") + case resp := <-ch: + if resp.Err != nil { + return nil, InternalError(resp.Err) + } + masterAddress := h.jettonMaster.ToRaw() + return &oas.State{ + TotalWallets: float64(resp.AccountsCount), + MasterAddress: masterAddress, + }, nil + } +} diff --git a/pkg/api/oas/oas_client_gen.go b/pkg/api/oas/oas_client_gen.go index 64f0824..6baf373 100644 --- a/pkg/api/oas/oas_client_gen.go +++ b/pkg/api/oas/oas_client_gen.go @@ -27,6 +27,10 @@ type Invoker interface { // // GET / GetApiInfo(ctx context.Context) (GetApiInfoOK, error) + // GetState invokes getState operation. + // + // GET /state + GetState(ctx context.Context) (*State, error) // GetWalletInfo invokes getWalletInfo operation. // // GET /wallet/{address} @@ -159,6 +163,76 @@ func (c *Client) sendGetApiInfo(ctx context.Context) (res GetApiInfoOK, err erro return result, nil } +// GetState invokes getState operation. +// +// GET /state +func (c *Client) GetState(ctx context.Context) (*State, error) { + res, err := c.sendGetState(ctx) + return res, err +} + +func (c *Client) sendGetState(ctx context.Context) (res *State, err error) { + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("getState"), + semconv.HTTPMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/state"), + } + + // Run stopwatch. + startTime := time.Now() + defer func() { + // Use floating point division here for higher precision (instead of Millisecond method). + elapsedDuration := time.Since(startTime) + c.duration.Record(ctx, float64(float64(elapsedDuration)/float64(time.Millisecond)), metric.WithAttributes(otelAttrs...)) + }() + + // Increment request counter. + c.requests.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + + // Start a span for this request. + ctx, span := c.cfg.Tracer.Start(ctx, "GetState", + trace.WithAttributes(otelAttrs...), + clientSpanKind, + ) + // Track stage for error reporting. + var stage string + defer func() { + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, stage) + c.errors.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + } + span.End() + }() + + stage = "BuildURL" + u := uri.Clone(c.requestURL(ctx)) + var pathParts [1]string + pathParts[0] = "/state" + uri.AddPathParts(u, pathParts[:]...) + + stage = "EncodeRequest" + r, err := ht.NewRequest(ctx, "GET", u) + if err != nil { + return res, errors.Wrap(err, "create request") + } + + stage = "SendRequest" + resp, err := c.cfg.Client.Do(r) + if err != nil { + return res, errors.Wrap(err, "do request") + } + defer resp.Body.Close() + + stage = "DecodeResponse" + result, err := decodeGetStateResponse(resp) + if err != nil { + return res, errors.Wrap(err, "decode response") + } + + return result, nil +} + // GetWalletInfo invokes getWalletInfo operation. // // GET /wallet/{address} diff --git a/pkg/api/oas/oas_handlers_gen.go b/pkg/api/oas/oas_handlers_gen.go index 4980a80..77f0ed2 100644 --- a/pkg/api/oas/oas_handlers_gen.go +++ b/pkg/api/oas/oas_handlers_gen.go @@ -122,6 +122,108 @@ func (s *Server) handleGetApiInfoRequest(args [0]string, argsEscaped bool, w htt } } +// handleGetStateRequest handles getState operation. +// +// GET /state +func (s *Server) handleGetStateRequest(args [0]string, argsEscaped bool, w http.ResponseWriter, r *http.Request) { + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("getState"), + semconv.HTTPMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/state"), + } + + // Start a span for this request. + ctx, span := s.cfg.Tracer.Start(r.Context(), "GetState", + trace.WithAttributes(otelAttrs...), + serverSpanKind, + ) + defer span.End() + + // Add Labeler to context. + labeler := &Labeler{attrs: otelAttrs} + ctx = contextWithLabeler(ctx, labeler) + + // Run stopwatch. + startTime := time.Now() + defer func() { + elapsedDuration := time.Since(startTime) + attrOpt := metric.WithAttributeSet(labeler.AttributeSet()) + + // Increment request counter. + s.requests.Add(ctx, 1, attrOpt) + + // Use floating point division here for higher precision (instead of Millisecond method). + s.duration.Record(ctx, float64(float64(elapsedDuration)/float64(time.Millisecond)), attrOpt) + }() + + var ( + recordError = func(stage string, err error) { + span.RecordError(err) + span.SetStatus(codes.Error, stage) + s.errors.Add(ctx, 1, metric.WithAttributeSet(labeler.AttributeSet())) + } + err error + ) + + var response *State + if m := s.cfg.Middleware; m != nil { + mreq := middleware.Request{ + Context: ctx, + OperationName: "GetState", + OperationSummary: "", + OperationID: "getState", + Body: nil, + Params: middleware.Parameters{}, + Raw: r, + } + + type ( + Request = struct{} + Params = struct{} + Response = *State + ) + response, err = middleware.HookMiddleware[ + Request, + Params, + Response, + ]( + m, + mreq, + nil, + func(ctx context.Context, request Request, params Params) (response Response, err error) { + response, err = s.h.GetState(ctx) + return response, err + }, + ) + } else { + response, err = s.h.GetState(ctx) + } + if err != nil { + if errRes, ok := errors.Into[*ErrorStatusCode](err); ok { + if err := encodeErrorResponse(errRes, w, span); err != nil { + defer recordError("Internal", err) + } + return + } + if errors.Is(err, ht.ErrNotImplemented) { + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + if err := encodeErrorResponse(s.h.NewError(ctx, err), w, span); err != nil { + defer recordError("Internal", err) + } + return + } + + if err := encodeGetStateResponse(response, w, span); err != nil { + defer recordError("EncodeResponse", err) + if !errors.Is(err, ht.ErrInternalServerErrorResponse) { + s.cfg.ErrorHandler(ctx, w, r, err) + } + return + } +} + // handleGetWalletInfoRequest handles getWalletInfo operation. // // GET /wallet/{address} diff --git a/pkg/api/oas/oas_json_gen.go b/pkg/api/oas/oas_json_gen.go index f08d320..7172fa8 100644 --- a/pkg/api/oas/oas_json_gen.go +++ b/pkg/api/oas/oas_json_gen.go @@ -176,6 +176,119 @@ func (s *OptWalletInfoCompressedInfo) UnmarshalJSON(data []byte) error { return s.Decode(d) } +// Encode implements json.Marshaler. +func (s *State) Encode(e *jx.Encoder) { + e.ObjStart() + s.encodeFields(e) + e.ObjEnd() +} + +// encodeFields encodes fields. +func (s *State) encodeFields(e *jx.Encoder) { + { + e.FieldStart("total_wallets") + e.Float64(s.TotalWallets) + } + { + e.FieldStart("master_address") + e.Str(s.MasterAddress) + } +} + +var jsonFieldsNameOfState = [2]string{ + 0: "total_wallets", + 1: "master_address", +} + +// Decode decodes State from json. +func (s *State) Decode(d *jx.Decoder) error { + if s == nil { + return errors.New("invalid: unable to decode State to nil") + } + var requiredBitSet [1]uint8 + + if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { + switch string(k) { + case "total_wallets": + requiredBitSet[0] |= 1 << 0 + if err := func() error { + v, err := d.Float64() + s.TotalWallets = float64(v) + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"total_wallets\"") + } + case "master_address": + requiredBitSet[0] |= 1 << 1 + if err := func() error { + v, err := d.Str() + s.MasterAddress = string(v) + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"master_address\"") + } + default: + return d.Skip() + } + return nil + }); err != nil { + return errors.Wrap(err, "decode State") + } + // Validate required fields. + var failures []validate.FieldError + for i, mask := range [1]uint8{ + 0b00000011, + } { + if result := (requiredBitSet[i] & mask) ^ mask; result != 0 { + // Mask only required fields and check equality to mask using XOR. + // + // If XOR result is not zero, result is not equal to expected, so some fields are missed. + // Bits of fields which would be set are actually bits of missed fields. + missed := bits.OnesCount8(result) + for bitN := 0; bitN < missed; bitN++ { + bitIdx := bits.TrailingZeros8(result) + fieldIdx := i*8 + bitIdx + var name string + if fieldIdx < len(jsonFieldsNameOfState) { + name = jsonFieldsNameOfState[fieldIdx] + } else { + name = strconv.Itoa(fieldIdx) + } + failures = append(failures, validate.FieldError{ + Name: name, + Error: validate.ErrFieldRequired, + }) + // Reset bit. + result &^= 1 << bitIdx + } + } + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } + + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s *State) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *State) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d) +} + // Encode implements json.Marshaler. func (s *WalletInfo) Encode(e *jx.Encoder) { e.ObjStart() diff --git a/pkg/api/oas/oas_response_decoders_gen.go b/pkg/api/oas/oas_response_decoders_gen.go index e4b954c..0689ce0 100644 --- a/pkg/api/oas/oas_response_decoders_gen.go +++ b/pkg/api/oas/oas_response_decoders_gen.go @@ -82,6 +82,98 @@ func decodeGetApiInfoResponse(resp *http.Response) (res GetApiInfoOK, _ error) { return res, errors.Wrap(defRes, "error") } +func decodeGetStateResponse(resp *http.Response) (res *State, _ error) { + switch resp.StatusCode { + case 200: + // Code 200. + ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return res, errors.Wrap(err, "parse media type") + } + switch { + case ct == "application/json": + buf, err := io.ReadAll(resp.Body) + if err != nil { + return res, err + } + d := jx.DecodeBytes(buf) + + var response State + if err := func() error { + if err := response.Decode(d); err != nil { + return err + } + if err := d.Skip(); err != io.EOF { + return errors.New("unexpected trailing data") + } + return nil + }(); err != nil { + err = &ogenerrors.DecodeBodyError{ + ContentType: ct, + Body: buf, + Err: err, + } + return res, err + } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } + return &response, nil + default: + return res, validate.InvalidContentType(ct) + } + } + // Convenient error response. + defRes, err := func() (res *ErrorStatusCode, err error) { + ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return res, errors.Wrap(err, "parse media type") + } + switch { + case ct == "application/json": + buf, err := io.ReadAll(resp.Body) + if err != nil { + return res, err + } + d := jx.DecodeBytes(buf) + + var response Error + if err := func() error { + if err := response.Decode(d); err != nil { + return err + } + if err := d.Skip(); err != io.EOF { + return errors.New("unexpected trailing data") + } + return nil + }(); err != nil { + err = &ogenerrors.DecodeBodyError{ + ContentType: ct, + Body: buf, + Err: err, + } + return res, err + } + return &ErrorStatusCode{ + StatusCode: resp.StatusCode, + Response: response, + }, nil + default: + return res, validate.InvalidContentType(ct) + } + }() + if err != nil { + return res, errors.Wrapf(err, "default (code %d)", resp.StatusCode) + } + return res, errors.Wrap(defRes, "error") +} + func decodeGetWalletInfoResponse(resp *http.Response) (res *WalletInfo, _ error) { switch resp.StatusCode { case 200: diff --git a/pkg/api/oas/oas_response_encoders_gen.go b/pkg/api/oas/oas_response_encoders_gen.go index 49d9800..d260106 100644 --- a/pkg/api/oas/oas_response_encoders_gen.go +++ b/pkg/api/oas/oas_response_encoders_gen.go @@ -27,6 +27,20 @@ func encodeGetApiInfoResponse(response GetApiInfoOK, w http.ResponseWriter, span return nil } +func encodeGetStateResponse(response *State, w http.ResponseWriter, span trace.Span) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + span.SetStatus(codes.Ok, http.StatusText(200)) + + e := new(jx.Encoder) + response.Encode(e) + if _, err := e.WriteTo(w); err != nil { + return errors.Wrap(err, "write") + } + + return nil +} + func encodeGetWalletInfoResponse(response *WalletInfo, w http.ResponseWriter, span trace.Span) error { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(200) diff --git a/pkg/api/oas/oas_router_gen.go b/pkg/api/oas/oas_router_gen.go index 1e8261b..81e86c1 100644 --- a/pkg/api/oas/oas_router_gen.go +++ b/pkg/api/oas/oas_router_gen.go @@ -68,6 +68,27 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } switch elem[0] { + case 's': // Prefix: "state" + origElem := elem + if l := len("state"); len(elem) >= l && elem[0:l] == "state" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch r.Method { + case "GET": + s.handleGetStateRequest([0]string{}, elemIsEscaped, w, r) + default: + s.notAllowed(w, r, "GET") + } + + return + } + + elem = origElem case 'w': // Prefix: "wallet" origElem := elem if l := len("wallet"); len(elem) >= l && elem[0:l] == "wallet" { @@ -238,6 +259,31 @@ func (s *Server) FindPath(method string, u *url.URL) (r Route, _ bool) { } } switch elem[0] { + case 's': // Prefix: "state" + origElem := elem + if l := len("state"); len(elem) >= l && elem[0:l] == "state" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch method { + case "GET": + r.name = "GetState" + r.summary = "" + r.operationID = "getState" + r.pathPattern = "/state" + r.args = args + r.count = 0 + return r, true + default: + return + } + } + + elem = origElem case 'w': // Prefix: "wallet" origElem := elem if l := len("wallet"); len(elem) >= l && elem[0:l] == "wallet" { diff --git a/pkg/api/oas/oas_schemas_gen.go b/pkg/api/oas/oas_schemas_gen.go index f967b56..901944b 100644 --- a/pkg/api/oas/oas_schemas_gen.go +++ b/pkg/api/oas/oas_schemas_gen.go @@ -157,6 +157,32 @@ func (o OptWalletInfoCompressedInfo) Or(d WalletInfoCompressedInfo) WalletInfoCo return d } +// Ref: #/components/schemas/State +type State struct { + TotalWallets float64 `json:"total_wallets"` + MasterAddress string `json:"master_address"` +} + +// GetTotalWallets returns the value of TotalWallets. +func (s *State) GetTotalWallets() float64 { + return s.TotalWallets +} + +// GetMasterAddress returns the value of MasterAddress. +func (s *State) GetMasterAddress() string { + return s.MasterAddress +} + +// SetTotalWallets sets the value of TotalWallets. +func (s *State) SetTotalWallets(val float64) { + s.TotalWallets = val +} + +// SetMasterAddress sets the value of MasterAddress. +func (s *State) SetMasterAddress(val string) { + s.MasterAddress = val +} + // Ref: #/components/schemas/WalletInfo type WalletInfo struct { Owner string `json:"owner"` diff --git a/pkg/api/oas/oas_server_gen.go b/pkg/api/oas/oas_server_gen.go index c003bb2..67b9509 100644 --- a/pkg/api/oas/oas_server_gen.go +++ b/pkg/api/oas/oas_server_gen.go @@ -12,6 +12,10 @@ type Handler interface { // // GET / GetApiInfo(ctx context.Context) (GetApiInfoOK, error) + // GetState implements getState operation. + // + // GET /state + GetState(ctx context.Context) (*State, error) // GetWalletInfo implements getWalletInfo operation. // // GET /wallet/{address} diff --git a/pkg/api/oas/oas_unimplemented_gen.go b/pkg/api/oas/oas_unimplemented_gen.go index 1cf4230..3350e04 100644 --- a/pkg/api/oas/oas_unimplemented_gen.go +++ b/pkg/api/oas/oas_unimplemented_gen.go @@ -20,6 +20,13 @@ func (UnimplementedHandler) GetApiInfo(ctx context.Context) (r GetApiInfoOK, _ e return r, ht.ErrNotImplemented } +// GetState implements getState operation. +// +// GET /state +func (UnimplementedHandler) GetState(ctx context.Context) (r *State, _ error) { + return r, ht.ErrNotImplemented +} + // GetWalletInfo implements getWalletInfo operation. // // GET /wallet/{address} diff --git a/pkg/api/oas/oas_validators_gen.go b/pkg/api/oas/oas_validators_gen.go index 2d5194e..27e3e5b 100644 --- a/pkg/api/oas/oas_validators_gen.go +++ b/pkg/api/oas/oas_validators_gen.go @@ -8,6 +8,29 @@ import ( "github.com/ogen-go/ogen/validate" ) +func (s *State) Validate() error { + if s == nil { + return validate.ErrNilPointer + } + + var failures []validate.FieldError + if err := func() error { + if err := (validate.Float{}).Validate(float64(s.TotalWallets)); err != nil { + return errors.Wrap(err, "float") + } + return nil + }(); err != nil { + failures = append(failures, validate.FieldError{ + Name: "total_wallets", + Error: err, + }) + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } + return nil +} + func (s *WalletList) Validate() error { if s == nil { return validate.ErrNilPointer diff --git a/pkg/prover/enumerate.go b/pkg/prover/enumerate.go index 520c2f2..b354071 100644 --- a/pkg/prover/enumerate.go +++ b/pkg/prover/enumerate.go @@ -254,3 +254,54 @@ func walk(startKey *boc.BitString, prefix *boc.BitString, cell *boc.Cell, count } return append(arrLeft, arrRight...), nil } + +func countLeaves(prefix *boc.BitString, cell *boc.Cell) (int, error) { + prefix.ResetCounter() + size := 267 - prefix.BitsAvailableForRead() + prefixSize, nextPrefix, err := readCommonPrefix(size, cell) + if err != nil { + return 0, err + } + currentPrefix, err := concatBitStrings(prefix, nextPrefix) + if err != nil { + return 0, err + } + if size == prefixSize { + return 1, nil + } + + left, err := addBit(currentPrefix, false) + if err != nil { + return 0, err + } + nxt, err := cell.NextRef() + if err != nil { + return 0, err // must be unreachable since leaf has been handled before + } + leftLeaves, err := countLeaves(left, nxt) + if err != nil { + right, err := addBit(currentPrefix, true) + if err != nil { + return 0, err + } + rightLeaves, err := countLeaves(right, nxt) + if err != nil { + return 0, err + } + + return rightLeaves, nil + } + right, err := addBit(currentPrefix, true) + if err != nil { + return 0, err + } + nxt, err = cell.NextRef() + if err != nil { + return 0, err + } + rightLeaves, err := countLeaves(right, nxt) + if err != nil { + return 0, err + } + return leftLeaves + rightLeaves, nil +} diff --git a/pkg/prover/enumerate_test.go b/pkg/prover/enumerate_test.go new file mode 100644 index 0000000..67a55d6 --- /dev/null +++ b/pkg/prover/enumerate_test.go @@ -0,0 +1,40 @@ +package prover + +import ( + "fmt" + "github.com/tonkeeper/tongo/boc" + "os" + "testing" +) + +func getAirdropRoot() (*boc.Cell, error) { + content, err := os.ReadFile("testdata/airdropData.boc") + if err != nil { + return nil, err + } + airdropCells, err := boc.DeserializeBoc(content) + if err != nil { + return nil, err + } + if len(airdropCells) != 1 { + return nil, fmt.Errorf("incorrect number of roots") + } + root := airdropCells[0] + + return root, nil +} + +func Test_countLeaves(t *testing.T) { + root, err := getAirdropRoot() + if err != nil { + t.Fatal(err) + } + prefix := boc.NewBitString(0) + cnt, err := countLeaves(&prefix, root) + if err != nil { + t.Fatal(err) + } + if cnt != 360 { + t.Errorf("incorrect number of leaves: got %v, want %v", cnt, 360) + } +} diff --git a/pkg/prover/prover.go b/pkg/prover/prover.go index 5d3c420..c6ae966 100644 --- a/pkg/prover/prover.go +++ b/pkg/prover/prover.go @@ -14,6 +14,15 @@ import ( "github.com/tonkeeper/claim-api-go/pkg/utils" ) +type AccountCountResponse struct { + AccountsCount int + Err error +} + +type AccountCountRequest struct { + ResponseCh chan<- AccountCountResponse +} + type ProofResponse struct { WalletAirdrop WalletAirdrop Err error @@ -110,6 +119,8 @@ func (p *Prover) Run(ctx context.Context) { p.processProofRequest(req) case EnumerateRequest: p.processEnumerateAccountsRequest(req) + case AccountCountRequest: + p.processCountAccounts(req) default: p.logger.Error("unexpected request type", zap.Any("reqAny", reqAny)) } @@ -117,6 +128,24 @@ func (p *Prover) Run(ctx context.Context) { } } +func (p *Prover) processCountAccounts(req AccountCountRequest) { + timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { + proverTimeHistogramVec.WithLabelValues("processCountAccounts").Observe(v) + })) + defer timer.ObserveDuration() + + jettonWalletCount, err := countAccounts(p.root) + if err != nil { + req.ResponseCh <- AccountCountResponse{ + Err: err, + } + return + } + req.ResponseCh <- AccountCountResponse{ + AccountsCount: jettonWalletCount, + } +} + func (p *Prover) processProofRequest(req ProofRequest) { timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { proverTimeHistogramVec.WithLabelValues("processProofRequest").Observe(v) @@ -185,6 +214,15 @@ func prove(accountID ton.AccountID, prover *boc.MerkleProver, root *boc.Cell) (W }, nil } +func countAccounts(root *boc.Cell) (int, error) { + prefix := boc.NewBitString(0) + cnt, err := countLeaves(&prefix, root) + if err != nil { + return 0, err + } + return cnt, nil +} + func enumerateAccounts(nextFrom ton.AccountID, root *boc.Cell, count int) ([]walletData, error) { root.ResetCounters() prefix := boc.NewBitString(0) From 7991a50a65e0a66aa768afa22769b4999c1e1d06 Mon Sep 17 00:00:00 2001 From: domikedos Date: Tue, 17 Jun 2025 17:03:16 +0600 Subject: [PATCH 2/2] count jetton wallets on start --- pkg/api/handler.go | 36 ++++++++++++++------------------- pkg/prover/prover.go | 47 +++++++++----------------------------------- 2 files changed, 24 insertions(+), 59 deletions(-) diff --git a/pkg/api/handler.go b/pkg/api/handler.go index 8196064..6f6107f 100644 --- a/pkg/api/handler.go +++ b/pkg/api/handler.go @@ -30,10 +30,11 @@ import ( type Handler struct { logger *zap.Logger - prover *prover.Prover - jettonMaster ton.AccountID - cli *liteapi.Client - config string + prover *prover.Prover + jettonMaster ton.AccountID + jettonWalletsQuantity int + cli *liteapi.Client + config string proofsCache utils.Cache[ton.AccountID, prover.WalletAirdrop] keyNotFoundCache utils.Cache[ton.AccountID, struct{}] @@ -78,11 +79,16 @@ func NewHandler(logger *zap.Logger, config Config) (*Handler, error) { if err != nil { return nil, fmt.Errorf("failed to create prover: %w", err) } + jettonWalletQuantity, err := p.CountJettonWallets() + if err != nil { + return nil, fmt.Errorf("failed to count jetton wallets") + } return &Handler{ prover: p, cli: cli, logger: logger, jettonMaster: config.JettonMaster, + jettonWalletsQuantity: jettonWalletQuantity, jettonMasterStateCache: map[ton.AccountID][2]string{}, config: blockchainConfig, proofsCache: utils.NewLRUCache[ton.AccountID, prover.WalletAirdrop](700_000, "proofs"), @@ -398,21 +404,9 @@ func (h *Handler) jettonMasterState(accountID ton.AccountID) (state [2]string, o } func (h *Handler) GetState(ctx context.Context) (*oas.State, error) { - ch := make(chan prover.AccountCountResponse, 1) - h.prover.Queue() <- prover.AccountCountRequest{ - ResponseCh: ch, - } - select { - case <-ctx.Done(): - return nil, BadRequest("timeout") - case resp := <-ch: - if resp.Err != nil { - return nil, InternalError(resp.Err) - } - masterAddress := h.jettonMaster.ToRaw() - return &oas.State{ - TotalWallets: float64(resp.AccountsCount), - MasterAddress: masterAddress, - }, nil - } + masterAddress := h.jettonMaster.ToRaw() + return &oas.State{ + TotalWallets: float64(h.jettonWalletsQuantity), + MasterAddress: masterAddress, + }, nil } diff --git a/pkg/prover/prover.go b/pkg/prover/prover.go index c6ae966..b524adc 100644 --- a/pkg/prover/prover.go +++ b/pkg/prover/prover.go @@ -14,15 +14,6 @@ import ( "github.com/tonkeeper/claim-api-go/pkg/utils" ) -type AccountCountResponse struct { - AccountsCount int - Err error -} - -type AccountCountRequest struct { - ResponseCh chan<- AccountCountResponse -} - type ProofResponse struct { WalletAirdrop WalletAirdrop Err error @@ -107,6 +98,15 @@ func (p *Prover) MerkleRoot() tlb.Bits256 { return p.merkleRoot } +func (p *Prover) CountJettonWallets() (int, error) { + prefix := boc.NewBitString(0) + cnt, err := countLeaves(&prefix, p.root) + if err != nil { + return 0, err + } + return cnt, nil +} + func (p *Prover) Run(ctx context.Context) { go p.queue.Run(ctx) for { @@ -119,8 +119,6 @@ func (p *Prover) Run(ctx context.Context) { p.processProofRequest(req) case EnumerateRequest: p.processEnumerateAccountsRequest(req) - case AccountCountRequest: - p.processCountAccounts(req) default: p.logger.Error("unexpected request type", zap.Any("reqAny", reqAny)) } @@ -128,24 +126,6 @@ func (p *Prover) Run(ctx context.Context) { } } -func (p *Prover) processCountAccounts(req AccountCountRequest) { - timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { - proverTimeHistogramVec.WithLabelValues("processCountAccounts").Observe(v) - })) - defer timer.ObserveDuration() - - jettonWalletCount, err := countAccounts(p.root) - if err != nil { - req.ResponseCh <- AccountCountResponse{ - Err: err, - } - return - } - req.ResponseCh <- AccountCountResponse{ - AccountsCount: jettonWalletCount, - } -} - func (p *Prover) processProofRequest(req ProofRequest) { timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { proverTimeHistogramVec.WithLabelValues("processProofRequest").Observe(v) @@ -214,15 +194,6 @@ func prove(accountID ton.AccountID, prover *boc.MerkleProver, root *boc.Cell) (W }, nil } -func countAccounts(root *boc.Cell) (int, error) { - prefix := boc.NewBitString(0) - cnt, err := countLeaves(&prefix, root) - if err != nil { - return 0, err - } - return cnt, nil -} - func enumerateAccounts(nextFrom ton.AccountID, root *boc.Cell, count int) ([]walletData, error) { root.ResetCounters() prefix := boc.NewBitString(0)