Skip to content

Commit 1d43844

Browse files
committed
Add interactive warehouse demo with auto-play, grid visualization, and difficulty selector
1 parent f628597 commit 1d43844

File tree

4 files changed

+344
-14
lines changed

4 files changed

+344
-14
lines changed

src/envs/warehouse_env/README.md

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,27 @@
1+
---
2+
title: Warehouse Env Environment Server
3+
emoji: 🏭
4+
colorFrom: blue
5+
colorTo: indigo
6+
sdk: docker
7+
pinned: false
8+
app_port: 8000
9+
base_path: /demo
10+
tags:
11+
- openenv
12+
- reinforcement-learning
13+
- logistics
14+
- warehouse
15+
- robotics
16+
---
17+
118
# Warehouse Optimization Environment
219

320
A grid-based warehouse logistics optimization environment for reinforcement learning. This environment simulates a warehouse robot that must navigate through obstacles, pick up packages from pickup zones, and deliver them to designated dropoff zones while optimizing for time and efficiency.
421

522
## Overview
623

7-
The Warehouse Environment is designed for training RL agents on logistics and pathfinding tasks. It features:
24+
The Warehouse Environment is designed for training reinforcement learning agents on logistics and pathfinding tasks. It features:
825

926
- **Grid-based navigation** with walls and obstacles
1027
- **Package pickup and delivery** mechanics

src/envs/warehouse_env/server/Dockerfile

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,8 @@ FROM python:3.11-slim
22

33
WORKDIR /app
44

5-
# Copy core dependencies
6-
COPY src/core /app/core
7-
8-
# Copy warehouse environment
9-
COPY src/envs/warehouse_env /app/envs/warehouse_env
5+
# Copy all warehouse environment files (for HF Spaces deployment)
6+
COPY . /app/
107

118
# Install Python dependencies
129
RUN pip install --no-cache-dir \
@@ -26,5 +23,8 @@ ENV NUM_PACKAGES=0
2623
ENV MAX_STEPS=0
2724
ENV RANDOM_SEED=0
2825

26+
# Set Python path to include current directory
27+
ENV PYTHONPATH=/app
28+
2929
# Run the server
30-
CMD ["python", "-m", "uvicorn", "envs.warehouse_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
30+
CMD ["python", "-m", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]

src/envs/warehouse_env/server/app.py

Lines changed: 200 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77

88
import os
99

10-
from core.env_server import create_fastapi_app
10+
from core.env_server import create_app
1111
from envs.warehouse_env.models import WarehouseAction, WarehouseObservation
1212
from envs.warehouse_env.server.warehouse_environment import WarehouseEnvironment
1313
from fastapi import FastAPI
14-
from fastapi.responses import JSONResponse
14+
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
1515

1616

1717
# Get configuration from environment variables
@@ -34,11 +34,56 @@
3434
)
3535

3636

37-
# Create FastAPI app using OpenEnv's helper
38-
app = create_fastapi_app(warehouse_env, WarehouseAction, WarehouseObservation)
37+
# Create FastAPI app using OpenEnv's helper (with web interface if enabled)
38+
app = create_app(warehouse_env, WarehouseAction, WarehouseObservation, env_name="warehouse_env")
39+
40+
41+
# Add custom endpoints (difficulty selection, rendering, auto-play)
42+
@app.post("/set-difficulty")
async def set_difficulty(request: dict):
    """Change the difficulty level and reset the environment.

    Args:
        request: Parsed JSON body. The optional "difficulty" key (default 2)
            must be convertible to an integer between 1 and 5.

    Returns:
        A JSONResponse describing the new configuration and the initial
        observation on success; status 400 when the difficulty is missing a
        usable integer value or is out of range; status 500 when recreating
        or resetting the environment fails.
    """
    global warehouse_env

    # Malformed input is a client error, so parse it outside the catch-all
    # handler below (which reports every failure as a 500).
    try:
        difficulty = int(request.get("difficulty", 2))
    except (TypeError, ValueError):
        return JSONResponse(
            status_code=400,
            content={"error": "Difficulty must be an integer between 1 and 5"},
        )

    if difficulty < 1 or difficulty > 5:
        return JSONResponse(
            status_code=400,
            content={"error": "Difficulty must be between 1 and 5"},
        )

    try:
        # Recreate the warehouse environment with the new difficulty; every
        # other setting is passed as None.
        # NOTE(review): this rebinds the module-level name only. The instance
        # handed to create_app() at import time is presumably still the one
        # used by the generated endpoints -- confirm they stay in sync.
        warehouse_env = WarehouseEnvironment(
            difficulty_level=difficulty,
            grid_width=None,
            grid_height=None,
            num_packages=None,
            max_steps=None,
            random_seed=None,
        )

        # Reset the environment so the response describes a fresh episode.
        observation = warehouse_env.reset()

        return JSONResponse(content={
            "success": True,
            "difficulty": difficulty,
            "grid_size": (warehouse_env.grid_width, warehouse_env.grid_height),
            "num_packages": warehouse_env.num_packages,
            "max_steps": warehouse_env.max_steps,
            "observation": {
                "step_count": observation.step_count,
                "packages_delivered": observation.packages_delivered,
                "total_packages": observation.total_packages,
                "robot_position": observation.robot_position,
            },
        })
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to set difficulty: {str(e)}"},
        )
3985

4086

41-
# Add custom render endpoint
4287
@app.get("/render")
4388
async def render():
4489
"""Get ASCII visualization of warehouse state."""
@@ -50,6 +95,143 @@ async def render():
5095
status_code=500, content={"error": f"Failed to render: {str(e)}"}
5196
)
5297

98+
@app.get("/render/html")
async def render_html():
    """Return the current warehouse state rendered as HTML.

    Returns:
        An HTMLResponse wrapping the output of ``warehouse_env.render_html()``,
        or a JSON error payload with status 500 if rendering raises.
    """
    try:
        return HTMLResponse(content=warehouse_env.render_html())
    except Exception as exc:
        error_body = {"error": f"Failed to render HTML: {str(exc)}"}
        return JSONResponse(status_code=500, content=error_body)
108+
109+
@app.post("/auto-step")
async def auto_step():
    """Advance the environment by one step chosen by the greedy policy.

    Returns:
        A JSONResponse with the action taken and the resulting step summary,
        a "done" notice when the episode has already finished, or a JSON
        error payload with status 500 if anything raises.
    """
    try:
        # A finished episode is reported back instead of being stepped.
        if warehouse_env.is_done:
            finished = {
                "done": True,
                "message": "Episode finished. Reset to start a new episode.",
            }
            return JSONResponse(content=finished)

        # Ask the greedy policy for an action id and apply it.
        chosen = WarehouseAction(action_id=_get_greedy_action())
        outcome = warehouse_env.step(chosen)

        summary = {
            "action": chosen.action_name,
            "message": outcome.message,
            "reward": outcome.reward,
            "done": outcome.done,
            "step_count": outcome.step_count,
            "packages_delivered": outcome.packages_delivered,
            "robot_position": outcome.robot_position,
        }
        return JSONResponse(content=summary)
    except Exception as exc:
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to execute auto-step: {str(exc)}"},
        )
140+
141+
def _greedy_cell_is_passable(x: int, y: int) -> bool:
    """Return True if (x, y) lies inside the grid and is neither wall nor shelf."""
    WALL = 1
    SHELF = 2

    in_bounds = (
        0 <= x < warehouse_env.grid_width
        and 0 <= y < warehouse_env.grid_height
    )
    # The grid is indexed row-first: grid[y][x].
    return in_bounds and warehouse_env.grid[y][x] not in (WALL, SHELF)


def _get_greedy_action() -> int:
    """Pick the next action with a simple greedy policy and obstacle avoidance.

    The robot heads for the nearest waiting package (Manhattan distance) when
    it carries nothing, and for the carried package's dropoff location
    otherwise. Moves along the axis with the larger remaining distance are
    tried first, then the other axis, then perpendicular side-steps, and
    finally any passable neighbouring cell.

    Returns:
        An action id: 0 = UP, 1 = DOWN, 2 = LEFT, 3 = RIGHT, 4 = pick up,
        5 = drop off.
    """
    UP, DOWN, LEFT, RIGHT, PICK_UP, DROP_OFF = range(6)

    robot_x, robot_y = warehouse_env.robot_position
    carrying = warehouse_env.robot_carrying is not None
    # Action used when standing on the target or when no move is possible.
    interact = DROP_OFF if carrying else PICK_UP

    # Determine the target location.
    if not carrying:
        waiting = [
            p.pickup_location
            for p in warehouse_env.packages
            if p.status == "waiting"
        ]
        if not waiting:
            return PICK_UP  # Nothing left to fetch; try to pick up anyway.
        # min() keeps the first of equally distant packages, matching a
        # strict "<" comparison in a manual scan.
        target_x, target_y = min(
            waiting,
            key=lambda loc: abs(robot_x - loc[0]) + abs(robot_y - loc[1]),
        )
    else:
        carried = next(
            (p for p in warehouse_env.packages
             if p.id == warehouse_env.robot_carrying),
            None,
        )
        if not carried:
            return DROP_OFF  # Carried id matches no package; try to drop off.
        target_x, target_y = carried.dropoff_location

    # Already standing on the target: pick up or drop off.
    if robot_x == target_x and robot_y == target_y:
        return interact

    dx = target_x - robot_x
    dy = target_y - robot_y

    # Candidate step toward the target along each axis (at most one per axis).
    horizontal = []
    if dx > 0:
        horizontal.append((RIGHT, robot_x + 1, robot_y))
    elif dx < 0:
        horizontal.append((LEFT, robot_x - 1, robot_y))

    vertical = []
    if dy > 0:
        vertical.append((DOWN, robot_x, robot_y + 1))
    elif dy < 0:
        vertical.append((UP, robot_x, robot_y - 1))

    # Prefer the axis with the larger remaining distance; ties go vertical.
    if abs(dx) > abs(dy):
        moves = horizontal + vertical
    else:
        moves = vertical + horizontal

    # When aligned with the target on one axis, add perpendicular side-steps
    # as a fallback so the robot can go around an obstacle.
    if dx == 0 and dy != 0:
        moves.append((RIGHT, robot_x + 1, robot_y))
        moves.append((LEFT, robot_x - 1, robot_y))
    elif dy == 0 and dx != 0:
        moves.append((DOWN, robot_x, robot_y + 1))
        moves.append((UP, robot_x, robot_y - 1))

    # Take the first preferred move that lands on a passable cell.
    for action_id, new_x, new_y in moves:
        if _greedy_cell_is_passable(new_x, new_y):
            return action_id

    # No preferred move is possible: take any passable neighbouring cell.
    neighbours = [(UP, 0, -1), (DOWN, 0, 1), (LEFT, -1, 0), (RIGHT, 1, 0)]
    for action_id, step_x, step_y in neighbours:
        if _greedy_cell_is_passable(robot_x + step_x, robot_y + step_y):
            return action_id

    # Last resort: completely boxed in, so try pickup/dropoff.
    return interact
234+
53235

54236
# Add health check endpoint
55237
@app.get("/health")
@@ -65,6 +247,17 @@ async def health():
65247
}
66248

67249

250+
@app.get("/demo")
async def demo():
    """Serve the interactive demo page.

    Returns:
        A FileResponse for ``demo.html`` located next to this module, or an
        HTML error page with status 404 when that file is not present.
    """
    import pathlib

    demo_path = pathlib.Path(__file__).parent / "demo.html"

    # Report a missing page as 404 rather than a successful response, so
    # clients and health probes can tell the demo is unavailable.
    if not demo_path.is_file():
        return HTMLResponse(
            status_code=404,
            content="<h1>Demo page not found</h1><p>Please check the server configuration.</p>",
        )

    return FileResponse(demo_path)
259+
260+
68261
if __name__ == "__main__":
69262
import uvicorn
70263

@@ -74,8 +267,8 @@ def main():
74267
"""Entry point for warehouse-server command."""
75268
import uvicorn
76269
import os
77-
270+
78271
port = int(os.getenv("PORT", "8000"))
79272
host = os.getenv("HOST", "0.0.0.0")
80-
273+
81274
uvicorn.run(app, host=host, port=port)

0 commit comments

Comments
 (0)