Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions app/services/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import joblib
import numpy as np
import pandas as pd
from loguru import logger
from sklearn.linear_model import SGDRegressor
from sklearn.preprocessing import StandardScaler

try:
from services.features import get_feature_columns, make_features
Expand Down
160 changes: 66 additions & 94 deletions app/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_status_endpoint(client):
"""Test the /status endpoint."""
response = client.get("/status")
assert response.status_code == 200

data = response.json()
assert "models" in data
assert isinstance(data["models"], list)
Expand All @@ -27,36 +27,18 @@ def test_status_endpoint(client):
def test_train_endpoint_success(client):
"""Test the /train endpoint with valid data."""
# Mock the CSV reading and training
mock_df = pd.DataFrame(
{
"tenant_id": ["tenant_a"] * 15,
"product": ["apples"] * 15,
"date": pd.date_range("2024-01-01", periods=15),
"units_sold": [10, 12, 15, 8, 20, 18, 22, 25, 19, 30, 28, 35, 32, 40, 38],
"price": [
1.0,
1.1,
1.05,
0.95,
1.2,
1.15,
1.25,
1.3,
1.18,
1.4,
1.35,
1.45,
1.42,
1.5,
1.48,
],
}
)

with patch("pandas.read_csv", return_value=mock_df):
with patch("app.main.train_on_dataframe") as mock_train:
mock_df = pd.DataFrame({
'tenant_id': ['tenant_a'] * 15,
'product': ['apples'] * 15,
'date': pd.date_range('2024-01-01', periods=15),
'units_sold': [10, 12, 15, 8, 20, 18, 22, 25, 19, 30, 28, 35, 32, 40, 38],
'price': [1.0, 1.1, 1.05, 0.95, 1.2, 1.15, 1.25, 1.3, 1.18, 1.4, 1.35, 1.45, 1.42, 1.5, 1.48]
})

with patch('pandas.read_csv', return_value=mock_df):
with patch('app.main.train_on_dataframe') as mock_train:
response = client.post("/train")

assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
Expand All @@ -67,9 +49,9 @@ def test_train_endpoint_success(client):
def test_train_endpoint_handles_error(client):
"""Test the /train endpoint when training fails."""
# Mock read_csv to raise an exception
with patch("pandas.read_csv", side_effect=Exception("File not found")):
with patch('pandas.read_csv', side_effect=Exception("File not found")):
response = client.post("/train")

assert response.status_code == 500
data = response.json()
assert "detail" in data
Expand All @@ -78,13 +60,14 @@ def test_train_endpoint_handles_error(client):
def test_predict_endpoint_all_targets(client):
"""Test the /predict endpoint with all valid target options."""
targets = ["next_week", "this_month", "next_month"]

for target in targets:
response = client.get(
"/predict",
params={"tenant_id": "tenant_a", "product": "apples", "target": target},
)

response = client.get("/predict", params={
"tenant_id": "tenant_a",
"product": "apples",
"target": target
})

assert response.status_code == 200
data = response.json()
assert data["tenant_id"] == "tenant_a"
Expand All @@ -99,34 +82,33 @@ def test_predict_endpoint_different_tenants_products(client):
test_cases = [
("tenant_a", "apples"),
("tenant_b", "bananas"),
("tenant_c", "laptops"),
("tenant_c", "laptops")
]

for tenant_id, product in test_cases:
response = client.get(
"/predict",
params={"tenant_id": tenant_id, "product": product, "target": "next_week"},
)

# Some combinations might not exist, so we accept 200 or 404
response = client.get("/predict", params={
"tenant_id": tenant_id,
"product": product,
"target": "next_week"
})

assert response.status_code in [200, 404]


def test_predict_endpoint_forecast_structure(client):
"""Test the structure of forecast data."""
response = client.get(
"/predict",
params={"tenant_id": "tenant_a", "product": "apples", "target": "next_week"},
)

response = client.get("/predict", params={
"tenant_id": "tenant_a",
"product": "apples",
"target": "next_week"
})

if response.status_code == 200:
data = response.json()
forecast = data["forecast"]

# next_week should return 7 days

assert len(forecast) == 7

# Check structure of each forecast item

for day in forecast:
assert "date" in day
assert "predicted_units" in day
Expand All @@ -135,85 +117,75 @@ def test_predict_endpoint_forecast_structure(client):

def test_predict_endpoint_this_month_structure(client):
"""Test the structure of this_month forecast."""
response = client.get(
"/predict",
params={"tenant_id": "tenant_a", "product": "apples", "target": "this_month"},
)

response = client.get("/predict", params={
"tenant_id": "tenant_a",
"product": "apples",
"target": "this_month"
})

if response.status_code == 200:
data = response.json()
forecast = data["forecast"]

# this_month should return remaining days in current month

assert len(forecast) > 0

for day in forecast:
assert "date" in day
assert "predicted_units" in day


def test_predict_endpoint_next_month_structure(client):
"""Test the structure of next_month forecast."""
response = client.get(
"/predict",
params={"tenant_id": "tenant_a", "product": "apples", "target": "next_month"},
)

response = client.get("/predict", params={
"tenant_id": "tenant_a",
"product": "apples",
"target": "next_month"
})

if response.status_code == 200:
data = response.json()
forecast = data["forecast"]

# next_month should return all days in next month

assert len(forecast) >= 28 # At least February
assert len(forecast) <= 31 # At most 31 days

assert len(forecast) <= 31 # At most 31 days
for day in forecast:
assert "date" in day
assert "predicted_units" in day


def test_cors_middleware(client):
"""Test that CORS middleware is properly configured."""
# Test preflight request

response = client.options("/predict")
# OPTIONS requests might return 405 Method Not Allowed, which is fine
assert response.status_code in [200, 405]

# Check CORS headers if response is successful
if response.status_code == 200:
assert "access-control-allow-origin" in response.headers
assert response.headers["access-control-allow-origin"] == "*"


def test_app_startup(client):
"""Test that app startup loads models and starts polling."""
# The fixture already tests startup by creating the client
# This test explicitly verifies the startup behavior
response = client.get("/status")
assert response.status_code == 200

# Should have loaded some models

data = response.json()
assert len(data["models"]) > 0


def test_predict_endpoint_parameter_validation(client):
"""Test various parameter combinations for predict endpoint."""
# Test with missing tenant_id
response = client.get(
"/predict", params={"product": "apples", "target": "next_week"}
)

response = client.get("/predict", params={"product": "apples", "target": "next_week"})
assert response.status_code == 422

# Test with missing product
response = client.get(
"/predict", params={"tenant_id": "tenant_a", "target": "next_week"}
)

response = client.get("/predict", params={"tenant_id": "tenant_a", "target": "next_week"})
assert response.status_code == 422

# Test with empty strings
response = client.get(
"/predict", params={"tenant_id": "", "product": "apples", "target": "next_week"}
)
# This should return 404 (model not found) rather than validation error
response = client.get("/predict", params={
"tenant_id": "",
"product": "apples",
"target": "next_week"
})
assert response.status_code == 404
Loading
Loading