Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ It can be used anywhere a type is accepted:
```py
from typing_extensions import LiteralString

x: LiteralString

def f():
def _(x: LiteralString):
reveal_type(x) # revealed: LiteralString
```

Expand Down Expand Up @@ -64,54 +62,60 @@ class C(LiteralString): ... # error: [invalid-base]
```py
from typing_extensions import LiteralString

foo: LiteralString = "foo"
reveal_type(foo) # revealed: Literal["foo"]

bar: LiteralString = "bar"
reveal_type(foo + bar) # revealed: Literal["foobar"]

baz: LiteralString = "baz"
baz += foo
reveal_type(baz) # revealed: Literal["bazfoo"]

qux = (foo, bar)
reveal_type(qux) # revealed: tuple[Literal["foo"], Literal["bar"]]

reveal_type(foo.join(qux)) # revealed: LiteralString

template: LiteralString = "{}, {}"
reveal_type(template) # revealed: Literal["{}, {}"]
reveal_type(template.format(foo, bar)) # revealed: LiteralString
def _(literal_a: LiteralString, literal_b: LiteralString, a_str: str):
# Addition
reveal_type(literal_a + literal_b) # revealed: LiteralString
reveal_type(literal_a + a_str) # revealed: str
reveal_type(a_str + literal_a) # revealed: str

# In-place addition
combined_literal = literal_a
combined_literal += literal_b
reveal_type(combined_literal) # revealed: LiteralString
combined_non_literal1 = literal_a
combined_non_literal1 += a_str
reveal_type(combined_non_literal1) # revealed: str
combined_non_literal2 = a_str
combined_non_literal2 += literal_a
reveal_type(combined_non_literal2) # revealed: str

# Join
reveal_type(literal_a.join(("abc", "foo", literal_a, literal_b))) # revealed: LiteralString
reveal_type(a_str.join(("abc", "foo", literal_a, literal_b))) # revealed: str
reveal_type(literal_a.join(("abc", "foo", a_str))) # revealed: str

# .format(…)
reveal_type("{}, {}".format(literal_a, literal_b)) # revealed: LiteralString
reveal_type("{}, {}".format(literal_a, a_str)) # revealed: str

# f-string
reveal_type(f"{literal_a} {literal_b}") # revealed: LiteralString
reveal_type(f"{literal_a} {a_str}") # revealed: str

# Repetition
reveal_type(literal_a * 10) # revealed: LiteralString
```

### Assignability

`Literal[""]` is assignable to `LiteralString`, and `LiteralString` is assignable to `str`, but not
vice versa.
`Literal["abc"]` is assignable to `LiteralString`, and `LiteralString` is assignable to `str`, but
not vice versa.

```py
from typing_extensions import Literal, LiteralString
from ty_extensions import static_assert, is_assignable_to

def _(flag: bool):
foo_1: Literal["foo"] = "foo"
bar_1: LiteralString = foo_1 # fine

foo_2 = "foo" if flag else "bar"
reveal_type(foo_2) # revealed: Literal["foo", "bar"]
bar_2: LiteralString = foo_2 # fine
static_assert(is_assignable_to(Literal[""], LiteralString))
static_assert(is_assignable_to(Literal["abc"], LiteralString))
static_assert(is_assignable_to(Literal["abc", "def"], LiteralString))

foo_3: LiteralString = "foo" * 1_000_000_000
bar_3: str = foo_2 # fine
static_assert(not is_assignable_to(LiteralString, Literal[""]))
static_assert(not is_assignable_to(LiteralString, Literal["abc"]))
static_assert(not is_assignable_to(LiteralString, Literal["abc", "def"]))

baz_1: str = repr(object())
qux_1: LiteralString = baz_1 # error: [invalid-assignment]
static_assert(is_assignable_to(LiteralString, str))

baz_2: LiteralString = "baz" * 1_000_000_000
qux_2: Literal["qux"] = baz_2 # error: [invalid-assignment]

baz_3 = "foo" if flag else 1
reveal_type(baz_3) # revealed: Literal["foo", 1]
qux_3: LiteralString = baz_3 # error: [invalid-assignment]
static_assert(not is_assignable_to(str, LiteralString))
```

### Narrowing
Expand Down Expand Up @@ -144,9 +148,7 @@ python-version = "3.11"
```py
from typing import LiteralString

x: LiteralString = "foo"

def f():
def _(x: LiteralString):
reveal_type(x) # revealed: LiteralString
```

Expand Down
28 changes: 20 additions & 8 deletions crates/ty_python_semantic/src/types/infer/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11162,12 +11162,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|| !conversion.is_none()
|| format_spec.is_some()
{
collector.add_expression();
collector.add_non_literal_string_expression();
} else {
if let Some(literal) = ty.str(self.db()).as_string_literal() {
let str_ty = ty.str(self.db());
if let Some(literal) = str_ty.as_string_literal() {
collector.push_str(literal.value(self.db()));
} else if str_ty.is_literal_string() {
collector.add_literal_string_expression();
} else {
collector.add_expression();
collector.add_non_literal_string_expression();
}
}
}
Expand Down Expand Up @@ -17250,14 +17253,14 @@ fn format_import_from_module(level: u32, module: Option<&str>) -> String {
#[derive(Debug)]
struct StringPartsCollector {
concatenated: Option<String>,
expression: bool,
contains_non_literal_str: bool,
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only renamed for clarity

}

impl StringPartsCollector {
fn new() -> Self {
Self {
concatenated: Some(String::new()),
expression: false,
contains_non_literal_str: false,
}
}

Expand All @@ -17274,13 +17277,22 @@ impl StringPartsCollector {
}
}

fn add_expression(&mut self) {
/// Add an expression whose `__str__` return type is `LiteralString`.
/// The exact value is unknown, so we can't track the concatenated string,
/// but the result is still `LiteralString`.
fn add_literal_string_expression(&mut self) {
self.concatenated = None;
self.expression = true;
}

/// Add an expression whose `__str__` return type is not `LiteralString`.
/// The result will degrade to `str`.
fn add_non_literal_string_expression(&mut self) {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as add_expression before, just renamed.

self.concatenated = None;
self.contains_non_literal_str = true;
}

fn string_type(self, db: &dyn Db) -> Type<'_> {
if self.expression {
if self.contains_non_literal_str {
KnownClass::Str.to_instance(db)
} else if let Some(concatenated) = self.concatenated {
Type::string_literal(db, &concatenated)
Expand Down
Loading