diff --git a/README.md b/README.md index e0a630a..5274fbb 100644 --- a/README.md +++ b/README.md @@ -72,8 +72,8 @@ Whole Pipeline (Enc+Dec) | FastSAM | Mobile Parameters | 68M | 9.66M Speed | 64ms |12ms -:muscle: **Does MobileSAM aign better with the original SAM than FastSAM? Yes!** -FastSAM is suggested to work with multiple points, thus we compare the mIoU with two prompt points (with different pixel distances) and show the resutls as follows. Higher mIoU indicates higher alignment. +:muscle: **Does MobileSAM align better with the original SAM than FastSAM? Yes!** +FastSAM is suggested to work with multiple points, thus we compare the mIoU with two prompt points (with different pixel distances) and show the results as follows. Higher mIoU indicates higher alignment. mIoU | FastSAM | MobileSAM :-----------------------------------------:|:---------|:-----: 100 | 0.27 | 0.73 diff --git a/app/apiserver.py b/app/apiserver.py new file mode 100644 index 0000000..1d8363d --- /dev/null +++ b/app/apiserver.py @@ -0,0 +1,54 @@ +import base64 +import PIL +import re +from io import BytesIO + +from flask import Flask, jsonify, request +from flask_cors import CORS +from app import get_points_with_draw, segment_with_points + +app = Flask(__name__) +CORS(app) + + +@app.route('/food_segmentation', methods=['POST']) +def food_segmentation(): + imageB64 = re.sub('^data:image/.+;base64,', '', request.json['imageB64']) + x1 = request.json['x1'] + y1 = request.json['y1'] + x2 = request.json['x2'] + y2 = request.json['y2'] + + if x1 == None or y1 == None or x2 == None or y2 == None: + result = { + "imageB64": request.json['imageB64'] + } + return jsonify(result) + image = PIL.Image.open(BytesIO(base64.b64decode(imageB64))) + + # print(image) + + get_points_with_draw(None, "Add Mask", x1+50, y1+50) + get_points_with_draw(None, "Add Mask", x2-50, y1+50) + get_points_with_draw(None, "Add Mask", x1+50, y2-50) + get_points_with_draw(None, "Add Mask", x2-50, y2-50) + + fig, _ = segment_with_points(image) + + img_byte_array = BytesIO() + fig.save(img_byte_array, format="PNG") + img_byte_array = img_byte_array.getvalue() + + # Bytes를 base64로 인코딩 + base64_encoded = "data:image/png;base64," + base64.b64encode(img_byte_array).decode("utf-8") + + result = { + "imageB64": base64_encoded + } + return jsonify(result) + + +if __name__ == '__main__': + ip_address = "127.0.0.1" + port_number = 8001 + app.run(ip_address, port=int(port_number), debug=True) diff --git a/app/app.py b/app/app.py index 71316fb..b62abe3 100644 --- a/app/app.py +++ b/app/app.py @@ -155,11 +155,11 @@ def segment_with_points( return fig, image -def get_points_with_draw(image, label, evt: gr.SelectData): +def get_points_with_draw(image, label, x, y): global global_points global global_point_label - x, y = evt.index[0], evt.index[1] + # x, y = evt.index[0], evt.index[1] point_radius, point_color = 15, (255, 255, 0) if label == "Add Mask" else ( 255, 0, @@ -171,11 +171,11 @@ def get_points_with_draw(image, label, evt: gr.SelectData): print(x, y, label == "Add Mask") # 创建一个可以在图像上绘图的对象 - draw = ImageDraw.Draw(image) - draw.ellipse( - [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], - fill=point_color, - ) + # draw = ImageDraw.Draw(image) + # draw.ellipse( + # [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], + # fill=point_color, + # ) return image @@ -199,131 +199,131 @@ def get_points_with_draw(image, label, evt: gr.SelectData): info="Our model was trained on a size of 1024", ) -with gr.Blocks(css=css, title="Faster Segment Anything(MobileSAM)") as demo: - with gr.Row(): - with gr.Column(scale=1): - # Title - gr.Markdown(title) - - # with gr.Tab("Everything mode"): - # # Images - # with gr.Row(variant="panel"): - # with gr.Column(scale=1): - # cond_img_e.render() - # - # with gr.Column(scale=1): - # segm_img_e.render() - # - # # Submit & Clear - # with gr.Row(): - # with gr.Column(): - # input_size_slider.render() - # - # with gr.Row(): - # contour_check = gr.Checkbox( - # value=True, - # label="withContours", - # info="draw the edges of the masks", - # ) - # - # with gr.Column(): - # segment_btn_e = gr.Button( - # "Segment Everything", variant="primary" - # ) - # clear_btn_e = gr.Button("Clear", variant="secondary") - # - # gr.Markdown("Try some of the examples below ⬇️") - # gr.Examples( - # examples=examples, - # inputs=[cond_img_e], - # outputs=segm_img_e, - # fn=segment_everything, - # cache_examples=True, - # examples_per_page=4, - # ) - # - # with gr.Column(): - # with gr.Accordion("Advanced options", open=False): - # # text_box = gr.Textbox(label="text prompt") - # with gr.Row(): - # mor_check = gr.Checkbox( - # value=False, - # label="better_visual_quality", - # info="better quality using morphologyEx", - # ) - # with gr.Column(): - # retina_check = gr.Checkbox( - # value=True, - # label="use_retina", - # info="draw high-resolution segmentation masks", - # ) - # # Description - # gr.Markdown(description_e) - # - with gr.Tab("Point mode"): - # Images - with gr.Row(variant="panel"): - with gr.Column(scale=1): - cond_img_p.render() - - with gr.Column(scale=1): - segm_img_p.render() - - # Submit & Clear - with gr.Row(): - with gr.Column(): - with gr.Row(): - add_or_remove = gr.Radio( - ["Add Mask", "Remove Area"], - value="Add Mask", - ) - - with gr.Column(): - segment_btn_p = gr.Button( - "Start segmenting!", variant="primary" - ) - clear_btn_p = gr.Button("Restart", variant="secondary") - - gr.Markdown("Try some of the examples below ⬇️") - gr.Examples( - examples=examples, - inputs=[cond_img_p], - # outputs=segm_img_p, - # fn=segment_with_points, - # cache_examples=True, - examples_per_page=4, - ) - - with gr.Column(): - # Description - gr.Markdown(description_p) - - cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p) - - # segment_btn_e.click( - # segment_everything, - # inputs=[ - # cond_img_e, - # input_size_slider, - # mor_check, - # contour_check, - # retina_check, - # ], - # outputs=segm_img_e, - # ) - - segment_btn_p.click( - segment_with_points, inputs=[cond_img_p], outputs=[segm_img_p, cond_img_p] - ) - - def clear(): - return None, None - - def clear_text(): - return None, None, None - - # clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e]) - clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p]) - -demo.queue() -demo.launch() +# with gr.Blocks(css=css, title="Faster Segment Anything(MobileSAM)") as demo: +# with gr.Row(): +# with gr.Column(scale=1): +# # Title +# gr.Markdown(title) +# +# # with gr.Tab("Everything mode"): +# # # Images +# # with gr.Row(variant="panel"): +# # with gr.Column(scale=1): +# # cond_img_e.render() +# # +# # with gr.Column(scale=1): +# # segm_img_e.render() +# # +# # # Submit & Clear +# # with gr.Row(): +# # with gr.Column(): +# # input_size_slider.render() +# # +# # with gr.Row(): +# # contour_check = gr.Checkbox( +# # value=True, +# # label="withContours", +# # info="draw the edges of the masks", +# # ) +# # +# # with gr.Column(): +# # segment_btn_e = gr.Button( +# # "Segment Everything", variant="primary" +# # ) +# # clear_btn_e = gr.Button("Clear", variant="secondary") +# # +# # gr.Markdown("Try some of the examples below ⬇️") +# # gr.Examples( +# # examples=examples, +# # inputs=[cond_img_e], +# # outputs=segm_img_e, +# # fn=segment_everything, +# # cache_examples=True, +# # examples_per_page=4, +# # ) +# # +# # with gr.Column(): +# # with gr.Accordion("Advanced options", open=False): +# # # text_box = gr.Textbox(label="text prompt") +# # with gr.Row(): +# # mor_check = gr.Checkbox( +# # value=False, +# # label="better_visual_quality", +# # info="better quality using morphologyEx", +# # ) +# # with gr.Column(): +# # retina_check = gr.Checkbox( +# # value=True, +# # label="use_retina", +# # info="draw high-resolution segmentation masks", +# # ) +# # # Description +# # gr.Markdown(description_e) +# # +# with gr.Tab("Point mode"): +# # Images +# with gr.Row(variant="panel"): +# with gr.Column(scale=1): +# cond_img_p.render() +# +# with gr.Column(scale=1): +# segm_img_p.render() +# +# # Submit & Clear +# with gr.Row(): +# with gr.Column(): +# with gr.Row(): +# add_or_remove = gr.Radio( +# ["Add Mask", "Remove Area"], +# value="Add Mask", +# ) +# +# with gr.Column(): +# segment_btn_p = gr.Button( +# "Start segmenting!", variant="primary" +# ) +# clear_btn_p = gr.Button("Restart", variant="secondary") +# +# gr.Markdown("Try some of the examples below ⬇️") +# gr.Examples( +# examples=examples, +# inputs=[cond_img_p], +# # outputs=segm_img_p, +# # fn=segment_with_points, +# # cache_examples=True, +# examples_per_page=4, +# ) +# +# with gr.Column(): +# # Description +# gr.Markdown(description_p) +# +# cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p) +# +# # segment_btn_e.click( +# # segment_everything, +# # inputs=[ +# # cond_img_e, +# # input_size_slider, +# # mor_check, +# # contour_check, +# # retina_check, +# # ], +# # outputs=segm_img_e, +# # ) +# +# segment_btn_p.click( +# segment_with_points, inputs=[cond_img_p], outputs=[segm_img_p, cond_img_p] +# ) +# +# def clear(): +# return None, None +# +# def clear_text(): +# return None, None, None +# +# # clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e]) +# clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p]) + +# demo.queue() +# demo.launch() diff --git a/app/requirements.txt b/app/requirements.txt index 79d0fca..213cade 100755 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -3,3 +3,5 @@ torchvision timm opencv-python git+https://github.com/dhkim2810/MobileSAM.git +flask +flask_cors \ No newline at end of file