diff --git a/.gitignore b/.gitignore
index c6b2baf..5d5d2bc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,7 @@
 train_dataset/**
 input/**
 submission.txt
 submit/**
+download.sh
 # model config
 # configs/**
diff --git a/demo/README.md b/demo/README.md
new file mode 100644
index 0000000..ea11282
--- /dev/null
+++ b/demo/README.md
@@ -0,0 +1,26 @@
+# Boostcamp Formula Recognizer Chrome Extension
+
+## Getting Started
+![image](demo.gif)
+### Prerequisites
+You need a trained Boostcamp formula recognition model.
+### Install
+
+#### Server (`server` folder)
+The model is not wired into the current server code. Adapt the inference code you used during the competition and plug it in here.
+```
+pip install -r requirements.txt
+python main.py
+```
+
+By default, the server is configured to return `a ^ { 2 } + b ^ { 2 } = c ^ { 2 }`.
+
+#### Chrome Extension (`boost_susik` folder)
+1. Start the server as described above, then set the matching URL as the variable at the top of `content.js`:
+   ```
+   const SERVERL_URL='http://<SERVER_IP>:<PORT>/susik_recognize';
+   ```
+2. Enter `chrome://extensions/` in the Chrome address bar, then click the `Load unpacked` button and load the `boost_susik` folder.
+
+
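As a quick check of the setup described above, the `/susik_recognize` contract can be exercised without the extension. This is a minimal sketch only: it assumes the server from `demo/server/main.py` is running locally on port 6006, that the third-party `requests` package is installed (it is not listed in `requirements.txt`), and that `formula.png` is some cropped formula image.

```python
# Sketch of the /susik_recognize contract shared by content.js and main.py.
# Assumptions: server on localhost:6006, `requests` installed, formula.png exists.
import base64
import requests

with open("formula.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

# content.js sends a data URL; main.py keeps only the part after the comma.
payload = {"image": "data:image/png;base64," + encoded}

resp = requests.post("http://127.0.0.1:6006/susik_recognize", json=payload)
print(resp.json()["result"])  # the decoded LaTeX token sequence
```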
diff --git a/demo/boost_susik/background.js b/demo/boost_susik/background.js
new file mode 100644
index 0000000..b0058bd
--- /dev/null
+++ b/demo/boost_susik/background.js
@@ -0,0 +1,11 @@
+chrome.runtime.onMessage.addListener(function(msg, sender, sendResponse) {
+    // Capture the currently visible tab as a PNG screenshot
+    chrome.tabs.captureVisibleTab(null, {
+        format : "png",
+        quality : 100
+    }, function(data) {
+        sendResponse(data);
+    });
+    return true;
+});
+
diff --git a/demo/boost_susik/content.js b/demo/boost_susik/content.js
new file mode 100644
index 0000000..abd7cbf
--- /dev/null
+++ b/demo/boost_susik/content.js
@@ -0,0 +1,182 @@
+const SERVERL_URL='http://118.67.134.149:6012/susik_recognize';
+
+fetch(chrome.runtime.getURL('/template.html')).then(r => r.text()).then(html => {
+    document.body.insertAdjacentHTML('beforeend', html);
+});
+
+let selectBoxX = -1
+let selectBoxY = -1
+
+
+document.body.addEventListener('click', e=>{
+    // When the close button at the bottom of the result box is clicked
+    if (e.target.id =='susik-close'){
+        let showBox = document.querySelector('#show-box');
+        let susikBox = document.querySelector('#susik-box');
+        let susikOutput = document.querySelector('#susik-output');
+        let susikImage = document.querySelector('#susik-image');
+        let susikLatex = document.querySelector('#susik-output-latex');
+
+        // Reset the related UI elements
+        showBox.style.display = 'none';
+        susikBox.style.width='0px';
+        susikBox.style.height='0px';
+        susikBox.style.top="-1px";
+        susikBox.style.left="-1px";
+        susikOutput.value = '';
+        susikImage.src = '';
+        susikLatex.src = '';
+    }
+    // When the fix button is clicked
+    if (e.target.id == 'fix-text'){
+        // Re-render the LaTeX image from the edited text
+        document.querySelector('#susik-output-latex').src = "http://latex.codecogs.com/gif.latex?"
+            + document.querySelector('#susik-output').value;
+    }
+    // When the copy button is clicked
+    if (e.target.id == 'copy-text'){
+        // Copy the recognized text
+        copyText()
+    }
+});
+
+
+document.body.addEventListener('mousedown', e => {
+    var isActivated = document.querySelector('#susik-box').getAttribute("data-activate");
+    let susikBox = document.querySelector('#susik-box');
+    susikBox.style.display='block';
+
+    if(isActivated=="true"){
+        let x = e.clientX;
+        let y = e.clientY;
+        // Register the current click position as the start point of the select box
+        selectBoxX = x;
+        selectBoxY = y;
+
+        // Reset the susik box position and size to the current point
+        susikBox.style.top = y+'px';
+        susikBox.style.left = x+'px';
+        susikBox.style.width='0px';
+        susikBox.style.height='0px';
+    }
+});
+
+
+// While capture is armed (mouse button held down), update the box size while dragging
+document.body.addEventListener('mousemove', e => {
+    try{
+        var susikBox = document.querySelector('#susik-box');
+        var isActivated = susikBox.getAttribute("data-activate");
+    }catch(e){
+        return;
+    }
+
+    // Proceed only if Start was clicked in the popup and the select box is no longer at its initial values
+    if(isActivated=="true" && (selectBoxX != -1 && selectBoxY != -1)){
+        let x = e.clientX;
+        let y = e.clientY;
+
+        // Resize the select box (susik-box) to follow the mouse movement
+        width = x-selectBoxX;
+        height = y-selectBoxY;
+
+        susikBox.style.width = width+'px';
+        susikBox.style.height = height+'px';
+
+    }
+});
+
+// When the drag ends (mouse released)
+document.body.addEventListener('mouseup', e => {
+    let susikBox = document.querySelector('#susik-box');
+    let isActivated = susikBox.getAttribute("data-activate");
+
+    // If capture was never armed from the popup, do nothing
+    if(isActivated=="false"){
+        return ;
+    }
+
+    // Disarm so subsequent mouse events are ignored
+    susikBox.setAttribute("data-activate", "false");
+
+
+
+    let x = parseInt(selectBoxX);
+    let y = parseInt(selectBoxY);
+    let w = parseInt(susikBox.style.width);
+    let h = parseInt(susikBox.style.height);
+
+    // Capture is done, so reset the susik-box state
+    selectBoxX = -1;
+    selectBoxY = -1;
+
+
+    susikBox.style.display='none';
+    susikBox.style.width='0px';
+    susikBox.style.height='0px';
+    susikBox.style.top="-1px";
+    susikBox.style.left="-1px";
+
+    // Hide the overlay
+    document.querySelector('#overlay').style.display='none';
+    // Restore the default mouse cursor
+    document.body.style.cursor = "default";
+
+
+    // Send the captured image to the server after a ~200ms delay;
+    // without the delay, the select box and overlay may end up in the capture
+    setTimeout(function(){
+        chrome.runtime.sendMessage({text:"hello"}, function(response) {
+            var img=new Image();
+            img.crossOrigin='anonymous';
+            img.onload=start;
+            img.src=response;
+
+            function start(){
+                // The capture may be scaled relative to on-screen coordinates (display ratio), so multiply by the corresponding ratio
+                ratio = img.width/window.innerWidth;
+
+
+                var croppedURL=cropPlusExport(img,x*ratio,y*ratio,w*ratio,h*ratio);
+                var cropImg=new Image();
+                cropImg.src=croppedURL;
+                document.querySelector('#susik-image').src = croppedURL;
+                fetch(SERVERL_URL, {
+                    method: 'POST',
+                    body: JSON.stringify({"image":croppedURL}), // data can be `string` or {object}!
+                    headers:{
+                        'Content-Type': 'application/json'
+                    }
+                }).then(res => res.json())
+                .then(response => {
+                    document.querySelector('#susik-output').value = response['result'];
+                    document.querySelector('#susik-output-latex').src = "http://latex.codecogs.com/gif.latex?"
+                        + response['result'];
+                });
+            }
+
+        });
+    },200);
+
+
+});
+
+// Crop the selected region out of the full screenshot
+function cropPlusExport(img,cropX,cropY,cropWidth,cropHeight){
+
+
+    var canvas1=document.createElement('canvas');
+    var ctx1=canvas1.getContext('2d');
+    canvas1.width=cropWidth;
+    canvas1.height=cropHeight;
+
+    ctx1.drawImage(img,cropX,cropY,cropWidth,cropHeight,0,0,cropWidth,cropHeight);
+
+    return(canvas1.toDataURL());
+}
+
+// Copy the contents of the output textbox
+function copyText() {
+    var obj = document.getElementById("susik-output");
+    obj.select(); // select the entire input contents
+    document.execCommand("copy"); // copy
+    obj.setSelectionRange(0, 0); // clear the selection
+}
diff --git a/demo/boost_susik/images/boostsusik128.png b/demo/boost_susik/images/boostsusik128.png
new file mode 100644
index 0000000..d837671
Binary files /dev/null and b/demo/boost_susik/images/boostsusik128.png differ
diff --git a/demo/boost_susik/images/boostsusik16.png b/demo/boost_susik/images/boostsusik16.png
new file mode 100644
index 0000000..8515d51
Binary files /dev/null and b/demo/boost_susik/images/boostsusik16.png differ
diff --git a/demo/boost_susik/images/boostsusik32.png b/demo/boost_susik/images/boostsusik32.png
new file mode 100644
index 0000000..0976ea8
Binary files /dev/null and b/demo/boost_susik/images/boostsusik32.png differ
diff --git a/demo/boost_susik/images/boostsusik48.png b/demo/boost_susik/images/boostsusik48.png
new file mode 100644
index 0000000..7946d25
Binary files /dev/null and b/demo/boost_susik/images/boostsusik48.png differ
diff --git a/demo/boost_susik/manifest.json b/demo/boost_susik/manifest.json
new file mode 100644
index 0000000..6e1b8c7
--- /dev/null
+++ b/demo/boost_susik/manifest.json
@@ -0,0 +1,39 @@
+{
+    "name": "Boost Susik",
+    "description": "Build an Extension!",
+    "version": "1.0",
+    "manifest_version": 3,
+    "background": {
+        "service_worker": "background.js"
+    },
+    "permissions": ["storage", "activeTab", "scripting"],
+    "content_scripts": [
+        {
+            "matches": ["<all_urls>"],
+            "js": ["content.js"]
+        }
+    ],
+    "action": {
+        "default_popup": "popup.html",
+        "default_icon": {
+            "16": "/images/boostsusik16.png",
+            "32": "/images/boostsusik32.png",
+            "48": "/images/boostsusik48.png",
+            "128": "/images/boostsusik128.png"
+        }
+    },
+    "icons": {
+        "16": "/images/boostsusik16.png",
+        "32": "/images/boostsusik32.png",
+        "48": "/images/boostsusik48.png",
+        "128": "/images/boostsusik128.png"
+    },
+    "web_accessible_resources": [
+        {
+            "resources": ["template.html"],
+            "matches": ["<all_urls>"]
+        }
+    ]
+}
\ No newline at end of file
diff --git a/demo/boost_susik/popup.html b/demo/boost_susik/popup.html
new file mode 100644
index 0000000..e72f3f1
--- /dev/null
+++ b/demo/boost_susik/popup.html
@@ -0,0 +1,9 @@
+<!-- (markup not preserved in this capture: the popup page that hosts the "select-susik" start button and loads popup.js) -->
\ No newline at end of file
diff --git a/demo/boost_susik/popup.js b/demo/boost_susik/popup.js
new file mode 100644
index 0000000..0703b00
--- /dev/null
+++ b/demo/boost_susik/popup.js
@@ -0,0 +1,24 @@
+// Grab the start button in the popup
+let selectSusik = document.getElementById("select-susik");
+
+
+
+// When the button is clicked, inject setSusikBox into the current page
+selectSusik.addEventListener("click", async () => {
+    let [tab] = await chrome.tabs.query({ active: true, currentWindow: true });
+    window.close();
+    chrome.scripting.executeScript({
+        target: { tabId: tab.id },
+        function: setSusikBox,
+    });
+});
+
+
+function setSusikBox() {
+    document.body.style.cursor = "cell";
+    document.querySelector('#susik-box').style.backgroundColor = 'none';
+    document.querySelector('#susik-box').setAttribute("data-activate", "true");
+    document.querySelector('#overlay').style.display = 'block';
+    document.querySelector('#show-box').style.display = 'block';
+
+}
\ No newline at end of file
diff --git a/demo/boost_susik/template.html b/demo/boost_susik/template.html
new file mode 100644
index 0000000..8279e83
--- /dev/null
+++ b/demo/boost_susik/template.html
@@ -0,0 +1,58 @@
+<!-- (markup not preserved in this capture: the injected overlay and result UI containing #overlay, #show-box, #susik-box, #susik-close, #fix-text, #copy-text, #susik-output, #susik-image and #susik-output-latex) -->
diff --git a/demo/demo.gif b/demo/demo.gif
new file mode 100644
index 0000000..a97bf52
Binary files /dev/null and b/demo/demo.gif differ
diff --git a/demo/server/inference.py b/demo/server/inference.py
new file mode 100644
index 0000000..304c537
--- /dev/null
+++ b/demo/server/inference.py
@@ -0,0 +1,65 @@
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from PIL import Image
+
+import sys
+sys.path.append("/opt/ml/code")
+
+from utils.flags import Flags
+from utils.utils import id_to_string_for_serve
+from data.augmentations import get_valid_transforms
+from postprocessing.postprocessing import get_decoding_manager
+from networks.EfficientSATRN import EfficientSATRN_for_serve
+# from networks.LiteSATRN import LiteSATRN
+
+
+def prepare_model():
+    checkpoint_path = "/opt/ml/code/models/satrn-fold-2-0.8171.pth"  # for EfficientSATRN
+    # checkpoint_path = "/opt/ml/code/models/LiteSATRN_best_model.pth"  # for LiteSATRN
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        device = torch.device("cpu")
+        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+
+    options = Flags(checkpoint["configs"]).get()
+
+    decoding_manager = False
+    manager = (
+        get_decoding_manager(
+            tokens_path="/opt/ml/input/data/train_dataset/tokens.txt", batch_size=batch_size
+        )
+        if decoding_manager
+        else None
+    )
+
+    model = EfficientSATRN_for_serve(options, checkpoint, decoding_manager).to(device)  # for EfficientSATRN
+    # model = LiteSATRN(options, checkpoint, decoding_manager).to(device)  # for LiteSATRN
+    model.eval()
+
+    return model, device, checkpoint, options
+
+
+def inference(model, device, checkpoint, options, image):
+    transforms = get_valid_transforms(
+        height=options.input_size.height, width=options.input_size.width
+    )
+    # image = Image.open(image)  # for test
+    image = image.convert("RGB")
+    w, h = image.size
+    if h / w > 2:
+        image = image.rotate(90, expand=True)
+    image = np.array(image)
+    image = transforms(image=image)["image"]
+
+    with torch.no_grad():
+        input = image.float().to(device)
+        output = model(input)
+        decoded_values = output.transpose(1, 2)  # [B, VOCAB_SIZE, MAX_LEN]
+        _, sequence = torch.topk(decoded_values, 1, dim=1)  # sequence: [B, 1, MAX_LEN]
+        sequence = sequence.squeeze(1)
+        sequence_str = id_to_string_for_serve(sequence, checkpoint, do_eval=1)
+
+    return sequence_str
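Before putting Flask in front of it, the serve pipeline above can be smoke-tested from a Python shell in `demo/server`. This is only a sketch: it assumes the hard-coded checkpoint and `/opt/ml` paths exist on the machine, and `sample.png` is a hypothetical cropped formula image.

```python
# Smoke test of the serve pipeline without Flask.
# Assumes the hard-coded paths in inference.py and a local sample.png.
from PIL import Image

from inference import prepare_model, inference

model, device, checkpoint, options = prepare_model()
image = Image.open("sample.png")
result = inference(model, device, checkpoint, options, image)
print(result)  # a list with one space-separated LaTeX token string
```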
diff --git a/demo/server/main.py b/demo/server/main.py
new file mode 100644
index 0000000..1ee0567
--- /dev/null
+++ b/demo/server/main.py
@@ -0,0 +1,40 @@
+from flask import Flask, request
+from PIL import Image
+import base64
+import json
+import io
+from inference import prepare_model, inference
+from flask_cors import CORS, cross_origin
+
+import time
+
+
+app = Flask(__name__)
+app.config['CORS_HEADERS'] = 'Content-Type'
+cors = CORS(app)
+
+start = time.time()
+model, device, checkpoint, options = prepare_model()
+print(f"Model loading time : {time.time() - start}")
+
+@app.route("/susik_recognize", methods=["POST"])
+@cross_origin()
+def susik_recognize():
+    image = request.json['image']
+
+    image = image.split(",")[1]
+    image = base64.b64decode(image)
+    image = io.BytesIO(image)
+    image = Image.open(image)
+
+    start = time.time()
+    output = inference(model, device, checkpoint, options, image)
+    print(f"Inference time : {time.time() - start}")
+
+    data = {'result':output}
+
+    return json.dumps(data)
+
+
+if __name__ == '__main__':
+    app.run(host="0.0.0.0", port=6006, debug=True)
\ No newline at end of file
diff --git a/demo/server/requirements.txt b/demo/server/requirements.txt
new file mode 100644
index 0000000..f35416d
--- /dev/null
+++ b/demo/server/requirements.txt
@@ -0,0 +1,3 @@
+flask
+Pillow
+flask_cors
\ No newline at end of file
diff --git a/networks/EfficientSATRN.py b/networks/EfficientSATRN.py
index afc4558..058c55c 100644
--- a/networks/EfficientSATRN.py
+++ b/networks/EfficientSATRN.py
@@ -554,7 +554,7 @@ def forward(
                 target, _out = self.manager.sift(_out[:, -1:, :])
             else:
                 target = torch.argmax(_out[:, -1:, :], dim=-1)  # [b, 1]
-                target = target.squeeze()  # [b]
+                target = target.squeeze(-1)  # [b] NOTE () -> (-1) for serve
             out.append(_out)
 
         out = torch.stack(out, dim=1).to(device)  # [b, max length, 1, class length]
@@ -565,6 +565,46 @@ def forward(
 
         return out
 
+    def forward_serve(self, src, batch_max_length=50):  # NOTE decoder forward for serve
+        out = []
+        num_steps = batch_max_length - 1
+        target = (
+            torch.LongTensor(src.size(0)).fill_(self.st_id).to(device)
+        )  # [START] token
+        features = [None] * self.layer_num
+
+        if self.manager:
+            self.manager.reset(sequence_length=num_steps)
+
+        for t in range(num_steps):
+            target = target.unsqueeze(1)
+            tgt = self.text_embedding(target)
+            tgt = self.pos_encoder(tgt, point=t)
+            tgt_mask = self.order_mask(t + 1)
+            tgt_mask = tgt_mask[:, -1].unsqueeze(1)  # [1, (l+1)]
+            for l, layer in enumerate(self.attention_layers):
+                tgt = layer(tgt, features[l], src, tgt_mask)
+                features[l] = (
+                    tgt if features[l] == None else torch.cat([features[l], tgt], 1)
+                )
+
+            _out = self.generator(tgt)  # [b, 1, c]
+
+            if self.manager:
+                target, _out = self.manager.sift(_out[:, -1:, :])
+            else:
+                target = torch.argmax(_out[:, -1:, :], dim=-1)  # [b, 1]
+                target = target.squeeze(0)  # [b]
+            out.append(_out)
+
+        out = torch.stack(out, dim=1).to(device)  # [b, max length, 1, class length]
+        out = out.squeeze(2)  # [b, max length, class length]
+
+        if self.manager:
+            self.manager.reset()
+
+        return out
+
 
 class SATRNDecoder_soft(nn.Module):
     """NOTE: used for greedy-decoding ensembles"""
@@ -1111,3 +1151,40 @@ def beam_search(
         outputs = torch.tensor(outputs)
 
     return outputs
+
+
+class EfficientSATRN_for_serve(nn.Module):  # NOTE EfficientSATRN for serve
+    def __init__(self, FLAGS, checkpoint=None, decoding_manager=None):
+        super(EfficientSATRN_for_serve, self).__init__()
+        self.encoder = SATRNEncoder(
+            input_height=FLAGS.input_size.height,
+            input_width=FLAGS.input_size.width,
+            input_channel=FLAGS.data.rgb,
+            hidden_size=FLAGS.SATRN.encoder.hidden_dim,
+            filter_size=FLAGS.SATRN.encoder.filter_dim,
+            head_num=FLAGS.SATRN.encoder.head_num,
+            layer_num=FLAGS.SATRN.encoder.layer_num,
+            dropout_rate=FLAGS.dropout_rate,
+        )
+
+        self.decoder = SATRNDecoder(
+            num_classes=len(checkpoint["id_to_token"]),
+            src_dim=FLAGS.SATRN.decoder.src_dim,
+            hidden_dim=FLAGS.SATRN.decoder.hidden_dim,
+            filter_dim=FLAGS.SATRN.decoder.filter_dim,
+            head_num=FLAGS.SATRN.decoder.head_num,
+            dropout_rate=FLAGS.dropout_rate,
+            pad_id=checkpoint["token_to_id"][PAD],
+            st_id=checkpoint["token_to_id"][START],
+            layer_num=FLAGS.SATRN.decoder.layer_num,
+            decoding_manager=decoding_manager,
+        )
+
+        if checkpoint["model"]:
+            self.load_state_dict(checkpoint["model"])
+
+    def forward(self, input):
+        enc_result = self.encoder(input.unsqueeze(0))
+        dec_result = self.decoder.forward_serve(src=enc_result, batch_max_length=50)
+
+        return dec_result
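The `squeeze()` → `squeeze(-1)` change (and the `squeeze(0)` used in `forward_serve`) matters precisely in the serving case, where the batch size is 1: a bare `squeeze()` drops every size-1 dimension, including the batch axis, which breaks the next embedding step. A minimal illustration of the shape difference (the vocabulary size here is arbitrary, for illustration only):

```python
import torch

# One image, one decoding step, as in serving; 245 is an arbitrary vocab size.
logits = torch.randn(1, 1, 245)
target = torch.argmax(logits[:, -1:, :], dim=-1)  # shape [1, 1]

print(target.squeeze().shape)    # torch.Size([])  -> the batch dimension is lost
print(target.squeeze(-1).shape)  # torch.Size([1]) -> one token id per batch element
```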
diff --git a/utils/utils.py b/utils/utils.py
index e66cee0..e17c300 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -164,6 +164,39 @@ def id_to_string(tokens, data_loader, do_eval=0):
     return result
 
 
+def id_to_string_for_serve(tokens, checkpoint, do_eval=0):  # NOTE id_to_string for serving
+    """Restore the decoder's predicted token ids into a formula string."""
+    result = []
+    if do_eval:
+        eos_id = checkpoint["token_to_id"]["<EOS>"]
+        special_ids = set(
+            [
+                checkpoint["token_to_id"]["<PAD>"],
+                checkpoint["token_to_id"]["<SOS>"],
+                eos_id,
+            ]
+        )
+
+    for example in tokens:
+        string = ""
+        if do_eval:
+            for token in example:
+                token = token.item()
+                if token not in special_ids:
+                    if token != -1:
+                        string += checkpoint["id_to_token"][token] + " "
+                elif token == eos_id:
+                    break
+        else:
+            for token in example:
+                token = token.item()
+                if token != -1:
+                    string += checkpoint["id_to_token"][token] + " "
+
+        result.append(string)
+    return result
+
+
 def set_seed(seed: int = 21):
     """Fix the random seed for reproducible experiments."""
     torch.manual_seed(seed)
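To make the id-to-string mapping concrete, here is a toy invocation of `id_to_string_for_serve` with a hand-built checkpoint dictionary, run from the repository root. The vocabulary and ids are invented for illustration only; the real mappings come from the training checkpoint, and the special-token names follow the `<PAD>`/`<SOS>`/`<EOS>` convention used by the training code.

```python
import torch

from utils.utils import id_to_string_for_serve

# Hypothetical miniature vocabulary; real ids come from the checkpoint.
checkpoint = {
    "token_to_id": {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "a": 3, "^": 4, "{": 5, "2": 6, "}": 7},
    "id_to_token": {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "a", 4: "^", 5: "{", 6: "2", 7: "}"},
}

sequence = torch.tensor([[3, 4, 5, 6, 7, 2, 0, 0]])  # "a ^ { 2 }" followed by <EOS> and padding
print(id_to_string_for_serve(sequence, checkpoint, do_eval=1))  # ['a ^ { 2 } ']
```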