Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 48 additions & 18 deletions openapi2cli/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ class CLIOption:
help: str = ""
is_flag: bool = False
multiple: bool = False
location: str = "query" # path, query, header, cookie, body, body_raw
api_name: str = "" # original API parameter/property name

@property
def help_literal(self) -> str:
"""Python string literal for Click help text."""
return repr(self.help or "")

@property
def default_literal(self) -> str:
"""Python string literal for default value."""
return repr(self.default)


@dataclass
Expand Down Expand Up @@ -55,6 +67,7 @@ class GeneratedCLI:
groups: List[CLIGroup] = field(default_factory=list)
global_options: List[CLIOption] = field(default_factory=list)
auth_schemes: List[AuthScheme] = field(default_factory=list)
api_key_header_name: str = ""

def to_python(self) -> str:
"""Generate Python code for the CLI."""
Expand Down Expand Up @@ -93,11 +106,12 @@ def generate(self, spec: ParsedSpec, name: str) -> GeneratedCLI:
return GeneratedCLI(
name=name,
version=spec.version,
description=spec.description or f"CLI for {spec.title}",
description=self._clean_text(spec.description) or f"CLI for {self._clean_text(spec.title)}",
base_url=spec.base_url,
groups=groups,
global_options=global_options,
auth_schemes=spec.auth_schemes,
api_key_header_name=self._api_key_header_name(spec.auth_schemes),
)

def _generate_global_options(self, spec: ParsedSpec) -> List[CLIOption]:
Expand Down Expand Up @@ -152,7 +166,7 @@ def _generate_group(self, tag: str, endpoints: List[Endpoint]) -> CLIGroup:

return CLIGroup(
name=self._sanitize_name(tag),
help=f"Commands for {tag}",
help=f"Commands for {self._clean_text(tag)}",
commands=commands,
)

Expand All @@ -174,7 +188,9 @@ def add_option(opt: CLIOption) -> None:
param_type=self._map_type(param.schema_type),
required=param.required,
default=str(param.default) if param.default is not None else None,
help=param.description or f"{param.name} parameter",
help=self._clean_text(param.description) or f"{param.name} parameter",
location=param.location,
api_name=param.name,
))

# Add options for request body properties
Expand All @@ -187,21 +203,25 @@ def add_option(opt: CLIOption) -> None:
name=f"--{self._sanitize_name(prop_name)}",
param_type=self._map_type(prop_schema.get('type', 'string')),
required=required,
help=prop_schema.get('description', f"{prop_name} field"),
help=self._clean_text(prop_schema.get('description', '')) or f"{prop_name} field",
location="body",
api_name=prop_name,
))

# Also add a --data option for raw JSON input
add_option(CLIOption(
name="--data",
param_type="str",
help="Raw JSON data for request body",
location="body_raw",
api_name="data",
))

return CLICommand(
name=endpoint.cli_name,
method=endpoint.method,
path=endpoint.path,
help=endpoint.summary or endpoint.description or f"{endpoint.method} {endpoint.path}",
help=self._clean_text(endpoint.summary) or self._clean_text(endpoint.description) or f"{endpoint.method} {endpoint.path}",
options=options,
has_body=has_body,
)
Expand Down Expand Up @@ -231,6 +251,19 @@ def _map_type(self, schema_type: str) -> str:
}
return mapping.get(schema_type, 'str')

def _clean_text(self, text: str) -> str:
"""Normalize free-text fields so they are safe in generated source strings."""
if not text:
return ""
return re.sub(r"\s+", " ", str(text)).strip()

def _api_key_header_name(self, auth_schemes: List[AuthScheme]) -> str:
"""Return the API key header name, if the spec defines one."""
for scheme in auth_schemes:
if scheme.type == "apiKey" and scheme.location == "header" and scheme.param_name:
return scheme.param_name
return ""


# Template for generated CLI - use raw strings to avoid escaping issues
CLI_TEMPLATE_STR = '''
Expand Down Expand Up @@ -271,12 +304,9 @@ def get_auth_headers(api_key: Optional[str] = None, token: Optional[str] = None)
if tok:
headers["Authorization"] = "Bearer " + tok
elif key:
{%- for scheme in cli.auth_schemes %}
{%- if scheme.type == "apiKey" and scheme.location == "header" %}
headers["{{ scheme.param_name }}"] = key
{%- endif %}
{%- endfor %}
{%- if not cli.auth_schemes %}
{%- if cli.api_key_header_name %}
headers["{{ cli.api_key_header_name }}"] = key
{%- else %}
headers["X-API-Key"] = key
{%- endif %}

Expand Down Expand Up @@ -380,7 +410,7 @@ def {{ group.name | replace("-", "_") }}():

@{{ group.name | replace("-", "_") }}.command("{{ cmd.name }}")
{%- for opt in cmd.options %}
@click.option("{{ opt.name }}"{% if opt.required %}, required=True{% endif %}{% if opt.default %}, default="{{ opt.default }}"{% endif %}, help="{{ opt.help | replace('"', '\\"') }}")
@click.option("{{ opt.name }}"{% if opt.required %}, required=True{% endif %}{% if opt.default is not none %}, default={{ opt.default_literal }}{% endif %}, help={{ opt.help_literal }})
{%- endfor %}
@click.pass_context
def {{ group.name | replace("-", "_") | replace(".", "_") }}_{{ cmd.name | replace("-", "_") | replace(".", "_") }}(ctx{% for opt in cmd.options %}, {{ opt.name | replace("--", "") | replace("-", "_") | replace(".", "_") }}{% endfor %}):
Expand All @@ -392,14 +422,14 @@ def {{ group.name | replace("-", "_") | replace(".", "_") }}_{{ cmd.name | repla
{%- for opt in cmd.options %}
{%- set var_name = opt.name | replace("--", "") | replace("-", "_") %}
if {{ var_name }} is not None:
{%- if "id" in opt.name.lower() and "{" in cmd.path %}
path_params["{{ opt.name | replace("--", "") | replace("-", "") }}"] = {{ var_name }}
{%- elif opt.name == "--data" %}
{%- if opt.location == "path" %}
path_params["{{ opt.api_name }}"] = {{ var_name }}
{%- elif opt.location == "body_raw" %}
body_data = json.loads({{ var_name }})
{%- elif cmd.has_body and opt.name != "--data" %}
body_data["{{ opt.name | replace("--", "") | replace("-", "_") }}"] = {{ var_name }}
{%- elif opt.location == "body" %}
body_data["{{ opt.api_name }}"] = {{ var_name }}
{%- else %}
query_params["{{ opt.name | replace("--", "") }}"] = {{ var_name }}
query_params["{{ opt.api_name }}"] = {{ var_name }}
{%- endif %}
{%- endfor %}

Expand Down
61 changes: 60 additions & 1 deletion tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path

from openapi2cli.generator import CLIGenerator, GeneratedCLI
from openapi2cli.parser import OpenAPIParser
from openapi2cli.parser import Endpoint, OpenAPIParser, Parameter, ParsedSpec

FIXTURES = Path(__file__).parent / "fixtures"

Expand Down Expand Up @@ -144,6 +144,65 @@ def test_generated_code_is_valid_python(self):
# Should compile without syntax errors
compile(code, "<generated>", "exec")

def test_escapes_multiline_quoted_help_text(self):
"""Escapes multiline parameter descriptions with embedded quotes."""
spec = ParsedSpec(
title="Demo",
version="1.0.0",
description="Demo API",
base_url="https://api.example.com",
endpoints=[
Endpoint(
path="/events",
method="GET",
operation_id="listEvents",
tags=["events"],
parameters=[
Parameter(
name="min_start_time",
location="query",
description='Include events after "2020-01-02T03:04:05.678Z".\nUse UTC.',
)
],
)
],
)

generator = CLIGenerator()
cli = generator.generate(spec, name="demo")
code = cli.to_python()

# Should compile without syntax errors from help string generation
compile(code, "<generated>", "exec")

def test_maps_path_and_query_params_using_openapi_names(self):
"""Uses OpenAPI param locations/names instead of CLI-name heuristics."""
spec = ParsedSpec(
title="Demo",
version="1.0.0",
description="Demo API",
base_url="https://api.example.com",
endpoints=[
Endpoint(
path="/users/{uuid}",
method="GET",
operation_id="getUser",
tags=["users"],
parameters=[
Parameter(name="uuid", location="path", required=True),
Parameter(name="min_start_time", location="query"),
],
)
],
)

generator = CLIGenerator()
cli = generator.generate(spec, name="demo")
code = cli.to_python()

assert 'path_params["uuid"] = uuid' in code
assert 'query_params["min_start_time"] = min_start_time' in code


class TestGeneratedCLI:
"""Tests for the GeneratedCLI data class."""
Expand Down