# batchProcess.py (scraped page chrome and line-number gutter removed)
"""Batch processing endpoint.
Accepts a CSV file with ``We`` and ``Oh`` columns (first row = header),
applies the same SL-theory input validation and ``predBeta`` model used by
``/regime`` for each row, and returns a CSV with ``beta`` filled in as the
third column.
No new runtime dependencies - uses Python stdlib ``csv`` only.
"""
import csv
import io
from pathlib import Path
from flask import Blueprint, request, jsonify, Response
from SLtheory_prediction import load_model_payload, predict_beta_from_payload
from theory_ranges import validate_theory_inputs
# Blueprint mounted by the main app; exposes POST /batch (see batch_process).
batch_bp = Blueprint("batch", __name__)
# Model payload is loaded once at import time from the JSON file that sits
# next to this module.  NOTE(review): import fails fast if the file is missing
# or malformed — presumably intentional; confirm with load_model_payload.
_MODEL_PAYLOAD = load_model_payload(
Path(__file__).resolve().with_name("SLtheory_model.json")
)
# Maximum accepted upload size in bytes (1 MiB).
MAX_BATCH_UPLOAD_BYTES = 1024 * 1024
def _format_row_error(line_num, message, we_raw, oh_raw):
return f"row {line_num}: {message} (We={we_raw!r}, Oh={oh_raw!r})"
@batch_bp.route("/batch", methods=["POST"])
def batch_process():
    """Process a CSV of (We, Oh) pairs and return a CSV with beta filled in.

    Expects a multipart upload in field ``file`` containing a ``.csv`` whose
    first (header) row has ``We`` and ``Oh`` columns.  Each data row is
    validated with ``validate_theory_inputs`` and, when valid, scored with
    ``predict_beta_from_payload``.

    Returns:
        On success, a ``text/csv`` attachment with columns
        ``We, Oh, beta`` followed by any extra input columns.  Rows that
        fail parsing/validation get ``beta == "error"`` and up to 10 such
        errors are summarized in the ``X-Row-Errors`` response header.
        Malformed requests get a JSON error with HTTP 400; oversized
        uploads get HTTP 413.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file attached. Expected multipart field 'file'."}), 400
    uploaded = request.files["file"]
    filename = uploaded.filename or ""
    if not filename.lower().endswith(".csv"):
        return jsonify({"error": "File must be a .csv"}), 400
    # Enforce the module-level upload cap (previously declared but never
    # checked, so arbitrarily large files were buffered in memory).
    # read(n + 1) returns at most n + 1 bytes, so an over-limit file is
    # detected without buffering all of it.
    raw_bytes = uploaded.read(MAX_BATCH_UPLOAD_BYTES + 1)
    if len(raw_bytes) > MAX_BATCH_UPLOAD_BYTES:
        return jsonify({"error": "File too large. Maximum upload size is 1 MiB."}), 413
    try:
        raw = raw_bytes.decode("utf-8-sig")  # utf-8-sig strips a BOM if present
    except UnicodeDecodeError:
        return jsonify({"error": "Could not decode file as UTF-8."}), 400
    reader = csv.DictReader(io.StringIO(raw))
    fieldnames = reader.fieldnames
    if not fieldnames or "We" not in fieldnames or "Oh" not in fieldnames:
        return jsonify(
            {"error": "CSV must have 'We' and 'Oh' columns in the first (header) row."}
        ), 400
    # Output column order: We, Oh, beta, then any remaining input columns.
    remaining = [col for col in fieldnames if col not in ("We", "Oh", "beta")]
    out_fieldnames = ["We", "Oh", "beta"] + remaining
    out_rows = []
    row_errors = []
    # start=2: file line 1 is the header, so the first data row is line 2.
    for line_num, row in enumerate(reader, start=2):
        we_raw = row.get("We", "")
        oh_raw = row.get("Oh", "")
        validation_error = None
        try:
            we = float(we_raw)
            oh = float(oh_raw)
            validation_error = validate_theory_inputs(we, oh)
            if validation_error is not None:
                # Route domain violations through the same error path as
                # parse failures; validation_error keeps the message.
                raise ValueError(validation_error)
            pred_beta = predict_beta_from_payload(_MODEL_PAYLOAD, oh=oh, we=we)
            row["beta"] = f"{pred_beta:.6f}"
        except (ValueError, TypeError):
            # float() raises ValueError on bad text and TypeError on None
            # (short/ragged row).  NOTE(review): this also catches
            # ValueError/TypeError raised inside the model call — presumably
            # acceptable as a per-row failure; confirm.
            message = (
                validation_error if validation_error is not None else "Invalid We/Oh inputs"
            )
            row_errors.append(_format_row_error(line_num, message, we_raw, oh_raw))
            row["beta"] = "error"
        out_rows.append(row)
    if not out_rows:
        return jsonify({"error": "CSV contained no data rows."}), 400
    # Build the response CSV.  extrasaction="ignore" drops row keys not in
    # out_fieldnames (e.g. the None key DictReader adds for ragged rows).
    buf = io.StringIO()
    writer = csv.DictWriter(
        buf, fieldnames=out_fieldnames, extrasaction="ignore", lineterminator="\n"
    )
    writer.writeheader()
    writer.writerows(out_rows)
    csv_bytes = buf.getvalue().encode("utf-8")
    response = Response(csv_bytes, mimetype="text/csv")
    response.headers["Content-Disposition"] = "attachment; filename=SLtheory_results.csv"
    if row_errors:
        # Non-fatal: surface at most the first 10 per-row errors in a
        # custom header so a valid partial result is still downloadable.
        response.headers["X-Row-Errors"] = "; ".join(row_errors[:10])
    return response