Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 174 additions & 57 deletions jupyterlab_tinyapp/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from .generation.generator import MockStreamingGenerator, OpenAIStreamingGenerator
from .generation.streaming import StreamParser
import ldap3
import requests
from urllib.parse import urlencode

app = Application.instance()
logger = logging.getLogger(app.log.name)
Expand Down Expand Up @@ -133,6 +135,15 @@ class TinyAppServerEndpoint(Enum):
LDAP_BIND_PASSWORD = os.getenv('LDAP_BIND_PASSWORD')
LDAP_USER_SEARCH_FILTER = os.getenv('LDAP_USER_SEARCH_FILTER', '(|(givenName=*{search}*)(sn=*{search}*))')
LDAP_USER_ATTRIBUTES = os.getenv('LDAP_USER_ATTRIBUTES', 'cn,uid,displayName,mail,givenName,sn').split(',')
LDAP_MAX_RESULTS = int(os.getenv('LDAP_MAX_RESULTS', '1000'))

# SCIM Configuration
USER_PROVIDER = os.getenv('USER_PROVIDER', 'ldap') # 'ldap' or 'scim'
SCIM_ENDPOINT = os.getenv('SCIM_ENDPOINT') # e.g., https://dev-12345.okta.com/api/v1/users
SCIM_TOKEN = os.getenv('SCIM_TOKEN') # API token for SCIM endpoint
SCIM_SEARCH_FILTER = os.getenv('SCIM_SEARCH_FILTER', 'profile.firstName sw "{search}" or profile.lastName sw "{search}" or profile.login sw "{search}"')
SCIM_MAX_RESULTS = int(os.getenv('SCIM_MAX_RESULTS', '100'))
SCIM_TIMEOUT = int(os.getenv('SCIM_TIMEOUT', '30'))

# VOLUME_CLAIM_NAME will be mounted on published app container. It is assumed that
# BASE_DIR and VOLUME_CLAIM_NAME refer to same file system - otherwise files
Expand Down Expand Up @@ -875,83 +886,189 @@ async def post(self):
}))


def get_ldap_attr(entry, attr_name):
"""Safely extract LDAP attribute value, handling multi-valued attributes"""
if hasattr(entry, attr_name):
attr_value = getattr(entry, attr_name)
if attr_value:
# Handle multi-valued attributes by taking the first value
if isinstance(attr_value, list) and len(attr_value) > 0:
first_value = attr_value[0]
else:
first_value = attr_value

# Convert bytes to string if necessary
if isinstance(first_value, bytes):
first_value = first_value.decode('utf-8')

# Return stripped string
return str(first_value).strip()
return ''


async def search_users_ldap(search_query):
"""Search users using LDAP"""
if not LDAP_ADDR or not LDAP_BASE_DN:
logger.warning('LDAP not configured')
return None, 'ldap is not configured: missing LDAP_ADDR or LDAP_BASE_DN'

# Create LDAP server object
try:
server = ldap3.Server(LDAP_ADDR, get_info=ldap3.ALL)
# Connect and bind to LDAP server
if LDAP_BIND_DN and LDAP_BIND_PASSWORD:
conn = ldap3.Connection(server, LDAP_BIND_DN, LDAP_BIND_PASSWORD, auto_bind=True)
else:
conn = ldap3.Connection(server, auto_bind=True)
except Exception as e:
logger.error(f'Error connecting to LDAP server: {str(e)}')
return None, 'Unable to connect to LDAP server'

# Search for users
search_filter = LDAP_USER_SEARCH_FILTER.format(search=ldap3.utils.conv.escape_filter_chars(search_query))

try:
success = conn.search(
search_base=LDAP_BASE_DN,
search_filter=search_filter,
search_scope=ldap3.SUBTREE,
attributes=LDAP_USER_ATTRIBUTES,
size_limit=LDAP_MAX_RESULTS
)

if not success:
conn.unbind()
logger.info('No LDAP results found')
return [], None
except Exception as e:
conn.unbind()
logger.error(f'Error during LDAP search: {str(e)}')
return None, 'error searching ldap directory'

# Process search results
users = []
for entry in conn.entries:
user_data = {
'uid': get_ldap_attr(entry, 'uid'),
'cn': get_ldap_attr(entry, 'cn'),
}

user_data['label'] = user_data['cn']
user_data['value'] = user_data['uid']

if user_data['value']: # Only include users with a valid identifier
users.append(user_data)

conn.unbind()
logger.info(f'Found {len(users)} LDAP users for query: {search_query}')
return users, None


async def search_users_scim(search_query):
"""Search users using SCIM API"""
if not SCIM_ENDPOINT or not SCIM_TOKEN:
logger.warning('SCIM not configured')
return None, 'SCIM is not configured: missing SCIM_ENDPOINT or SCIM_TOKEN'

try:
# Build SCIM query parameters
filter_expr = SCIM_SEARCH_FILTER.format(search=search_query)
params = {
'filter': filter_expr,
'count': SCIM_MAX_RESULTS,
'startIndex': 1
}

headers = {
'Authorization': f'SSWS {SCIM_TOKEN}',
'Accept': 'application/scim+json',
'Content-Type': 'application/scim+json'
}

# Make SCIM API request
response = requests.get(
SCIM_ENDPOINT,
params=params,
headers=headers,
timeout=SCIM_TIMEOUT
)

if response.status_code != 200:
logger.error(f'SCIM API error: {response.status_code} {response.text}')
return None, f'SCIM API error: {response.status_code}'

scim_data = response.json()

# Process SCIM results
users = []
resources = scim_data.get('Resources', [])

for user in resources:
# Extract user data from SCIM response
user_data = {
'uid': user.get('userName', ''),
'cn': user.get('displayName', ''),
}

# Fallback to profile data if needed
if not user_data['cn'] and 'profile' in user:
profile = user['profile']
first_name = profile.get('firstName', '')
last_name = profile.get('lastName', '')
if first_name or last_name:
user_data['cn'] = f"{first_name} {last_name}".strip()

# Set label and value for UI
user_data['label'] = user_data['cn'] or user_data['uid']
user_data['value'] = user_data['uid']

if user_data['value']: # Only include users with a valid identifier
users.append(user_data)

logger.info(f'Found {len(users)} SCIM users for query: {search_query}')
return users, None

except requests.exceptions.RequestException as e:
logger.error(f'SCIM request error: {str(e)}')
return None, 'Error connecting to SCIM endpoint'
except Exception as e:
logger.error(f'SCIM processing error: {str(e)}')
return None, 'Error processing SCIM response'


class SearchUsersHandler(CustomAPIHandler):
@tornado.web.authenticated
async def get(self):
logger.info('Received request to SearchUsersHandler')
logger.info(f'Received request to SearchUsersHandler (provider: {USER_PROVIDER})')

# Get search query parameter
search_query = self.get_argument('query', '')
if not search_query or len(search_query.strip()) < 2:
logger.info('Invalid query parameter: must be at least 2 characters')
self._return_error(400, 'query parameter must be at least 2 characters')

if not LDAP_ADDR or not LDAP_BASE_DN:
logger.warning('LDAP not configured')
self._return_error(500, 'ldap is not configured: missing LDAP_ADDR or LDAP_BASE_DN')
return

# Create LDAP server object
try:
server = ldap3.Server(LDAP_ADDR, get_info=ldap3.ALL)
# Connect and bind to LDAP server
if LDAP_BIND_DN and LDAP_BIND_PASSWORD:
conn = ldap3.Connection(server, LDAP_BIND_DN, LDAP_BIND_PASSWORD, auto_bind=True)
else:
conn = ldap3.Connection(server, auto_bind=True)
except Exception as e:
logger.error(f'Error connecting to LDAP server: {str(e)}')
self._return_error(500, 'Unable to connect to LDAP server')
# Route to appropriate search function based on configuration
if USER_PROVIDER.lower() == 'scim':
users, error = await search_users_scim(search_query)
elif USER_PROVIDER.lower() == 'ldap':
users, error = await search_users_ldap(search_query)
else:
logger.error(f'Unknown user provider: {USER_PROVIDER}')
self._return_error(500, f'Unknown user provider: {USER_PROVIDER}. Must be "ldap" or "scim"')
return

# Search for users
search_filter = LDAP_USER_SEARCH_FILTER.format(search=ldap3.utils.conv.escape_filter_chars(search_query))

try:
success = conn.search(
search_base=LDAP_BASE_DN,
search_filter=search_filter,
search_scope=ldap3.SUBTREE,
attributes=LDAP_USER_ATTRIBUTES,
size_limit=1000
)

if not success:
conn.unbind()
logger.error('No results found or error during LDAP search')
self._return_error(500, 'No results found or error during LDAP search')
return
except Exception as e:
conn.unbind()
logger.error(f'Error during LDAP search: {str(e)}')
self._return_error(500, 'error searching ldap directory')
# Handle errors
if error:
self._return_error(500, error)
return

# Process search results
users = []
for entry in conn.entries:
user_data = {
'uid': str(entry.uid) if hasattr(entry, 'uid') and entry.uid else '',
'cn': str(entry.cn) if hasattr(entry, 'cn') and entry.cn else '',
'displayName': str(entry.displayName) if hasattr(entry, 'displayName') and entry.displayName else '',
'mail': str(entry.mail) if hasattr(entry, 'mail') and entry.mail else ''
}

user_data['label'] = user_data['cn']
user_data['value'] = user_data['uid']

if user_data['value']: # Only include users with a valid identifier
users.append(user_data)

logger.info(f'Found {len(users)} users for query: {search_query}')

# Return results
self.finish(json.dumps({
'data': {
'users': users
}
}))

conn.unbind()


class PingHandler(CustomAPIHandler):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
"openai==2.6.1",
"aiofiles==23.2.1",
"ldap3>=2.9.1",
"requests>=2.25.0",
]
dynamic = ["version", "description", "authors", "urls", "keywords"]

Expand Down