-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
72 lines (50 loc) · 2.18 KB
/
server.py
File metadata and controls
72 lines (50 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from proof_search import generate_tac, BestFirstSearchProver
from flask import Flask, request, jsonify, render_template
# Flask application object; the @app.route decorators in this module
# register their view functions on it, and the __main__ guard runs it.
app = Flask(__name__)
@app.route('/project/ds-prover/')
def hello():
    """Render the ``index.html`` template for the project root URL."""
    page = render_template("index.html")
    return page
@app.route('/project/ds-prover/prove/', methods=['POST'])
def prove():
    """Run a best-first proof search on the statement posted as JSON.

    Request body (JSON object):
        statement (str, required): text handed to the prover after
            surrounding whitespace is stripped.
        model (optional): the prover is built with ``aug=True`` only when
            this equals ``"augmented"``.
        sampling (optional): the prover is built with ``dynamic=True`` only
            when this equals ``"dynamic"``.

    Returns:
        A ``(json, 400)`` pair when the body is not a non-empty JSON object
        or ``statement`` is missing / not a string. Otherwise a dict whose
        ``proved`` field is:
            0 -- search did not prove the statement; ``error`` is included.
            1 -- proved with a non-empty proof; ``statement`` (prefixed with
                 an ``import all`` line), ``proof`` and ``time`` included.
            2 -- proved and the returned proof list is empty; ``statement``
                 and ``time`` included.
    """
    # silent=True returns None for a wrong content type or malformed JSON
    # instead of raising, so the client always gets the JSON error below.
    data = request.get_json(silent=True)
    if not data or not isinstance(data, dict):
        return jsonify({'error': 'Request must be in JSON format'}), 400
    if 'statement' not in data:
        return jsonify({'error': 'Missing parameter: statement'}), 400
    if not isinstance(data['statement'], str):
        return jsonify({'error': 'statement must be a string'}), 400

    # "model" and "sampling" are optional: .get() makes an omitted field
    # behave like any non-matching value rather than raising KeyError.
    bfs = BestFirstSearchProver(
        timeout=300,
        aug=data.get("model") == "augmented",
        dynamic=data.get("sampling") == "dynamic",
    )
    result = bfs.search(data["statement"].strip())

    if not result['proved']:
        return {'proved': 0, 'error': result['error']}

    elapsed = round(result['total_time'], 1)
    if result['proof'] == []:
        # NOTE(review): presumably the statement was closed without any
        # tactic steps -- confirm against proof_search.
        return {'proved': 2, "statement": result['theorem'], "time": elapsed}

    imports = "import all\n"
    return {
        'proved': 1,
        "statement": imports + result['theorem'],
        "proof": result['proof'],
        "time": elapsed,
    }
@app.route('/project/ds-prover/generate/', methods=['POST'])
def generate():
    """Generate tactic candidates for the goal posted as JSON.

    Request body (JSON object):
        goal (str, required): passed to ``generate_tac`` as ``ts``.
        num_samples (int, optional): number of samples requested; clamped
            to the range 2..100. Defaults to 32 when omitted or not an
            integer.
        model (optional): ``generate_tac`` is called with ``aug=True`` only
            when this equals ``"augmented"``.

    Returns:
        A ``(json, 400)`` pair when the body is not a non-empty JSON object
        or ``goal`` is missing / not a string; otherwise whatever
        ``generate_tac`` returns.
    """
    # silent=True returns None for a wrong content type or malformed JSON
    # instead of raising, so the client always gets the JSON error below.
    data = request.get_json(silent=True)
    if not data or not isinstance(data, dict):
        return jsonify({'error': 'Request must be in JSON format'}), 400
    if 'goal' not in data:
        return jsonify({'error': 'Missing parameter: goal'}), 400
    if not isinstance(data['goal'], str):
        return jsonify({'error': 'goal must be a string'}), 400

    count = 32
    requested = data.get('num_samples')
    # bool is a subclass of int; exclude it so JSON true/false keeps the
    # default instead of being treated as 1/0.
    if isinstance(requested, int) and not isinstance(requested, bool):
        count = max(2, min(100, requested))

    try:
        result = generate_tac(
            ts=data['goal'],
            count=count,
            aug=data.get("model") == "augmented",
        )
    finally:
        # Release cached GPU memory even when generation fails, so an
        # erroring request does not leave the cache held.
        torch.cuda.empty_cache()
    return result
if __name__ == '__main__':
    import os

    # The server listens on all interfaces, and Werkzeug's interactive
    # debugger lets anyone who can reach it execute arbitrary code. Debug
    # mode is therefore opt-in: set FLASK_DEBUG=1 (or "true") for local
    # development only.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true")
    app.run(debug=debug_enabled, port=5001, host='0.0.0.0')