Skip to content

Commit c50c549

Browse files
authored
feat: improve file type check (#563)
1 parent dafb047 commit c50c549

File tree

7 files changed: +36 −19 lines changed

packages/ragbits-core/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Unreleased
44

5+
- Switch from imghdr to filetype for image file type check (#563)
56
- Remove prompt lab (#549)
67
- Add batched() helper method to utils (#555)
78
- Rename DocumentMeta create_text_document_from_literal to from_literal (#561)

packages/ragbits-core/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ dependencies = [
3838
"tomli>=2.0.2,<3.0.0",
3939
"litellm>=1.55.0,<2.0.0",
4040
"aiohttp>=3.10.8,<4.0.0",
41-
"standard-imghdr>=3.10.14,<3.14"
41+
"filetype>=1.2.0,<2.0.0",
4242
]
4343

4444
[project.urls]

packages/ragbits-core/src/ragbits/core/prompt/prompt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import asyncio
22
import base64
3-
import imghdr
43
import textwrap
54
from abc import ABCMeta
65
from collections.abc import Awaitable, Callable
76
from typing import Any, Generic, cast, get_args, get_origin, overload
87

8+
import filetype
99
from jinja2 import Environment, Template, meta
1010
from pydantic import BaseModel
1111
from typing_extensions import TypeVar, get_original_bases
@@ -288,10 +288,10 @@ def list_images(self) -> list[str]:
288288
@staticmethod
289289
def _create_message_with_image(image: str | bytes) -> dict:
290290
if isinstance(image, bytes):
291-
image_type = imghdr.what(None, image)
292-
if not image_type:
291+
detected_type = filetype.guess(image)
292+
if not detected_type or not detected_type.mime.startswith("image/"):
293293
raise PromptWithImagesOfInvalidFormat()
294-
image_url = f"data:image/{image_type};base64,{base64.b64encode(image).decode('utf-8')}"
294+
image_url = f"data:{detected_type.mime};base64,{base64.b64encode(image).decode('utf-8')}"
295295
else:
296296
image_url = image
297297
return {

packages/ragbits-document-search/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Unreleased
44

5+
- Improve document file type check (#563)
56
- Fix reranker options typing (#562)
67
- Add query rephraser options (#560)
78
- Rename DocumentMeta create_text_document_from_literal to from_literal (#561)

packages/ragbits-document-search/pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ classifiers = [
3131
"Topic :: Scientific/Engineering :: Artificial Intelligence",
3232
"Topic :: Software Development :: Libraries :: Python Modules",
3333
]
34-
dependencies = ["unstructured>=0.16.9", "unstructured-client>=0.26.0", "rerankers>=0.6.1", "ragbits-core==0.17.1"]
34+
dependencies = [
35+
"unstructured>=0.16.9",
36+
"unstructured-client>=0.26.0",
37+
"rerankers>=0.6.1",
38+
"filetype>=1.2.0,<2.0.0",
39+
"ragbits-core==0.17.1",
40+
]
3541

3642
[project.urls]
3743
"Homepage" = "https://github.com/deepsense-ai/ragbits"

packages/ragbits-document-search/src/ragbits/document_search/documents/document.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
from typing import Annotated, Any
55

6+
import filetype
67
from pydantic import BaseModel
78
from typing_extensions import deprecated
89

@@ -121,7 +122,7 @@ def from_local_path(cls, local_path: Path) -> "DocumentMeta":
121122
The document metadata.
122123
"""
123124
return cls(
124-
document_type=DocumentType(local_path.suffix[1:]),
125+
document_type=cls._infer_document_type(local_path),
125126
source=LocalFileSource(path=local_path),
126127
)
127128

@@ -139,10 +140,25 @@ async def from_source(cls, source: Source) -> "DocumentMeta":
139140
path = await source.fetch()
140141

141142
return cls(
142-
document_type=DocumentType(path.suffix[1:]),
143+
document_type=cls._infer_document_type(path),
143144
source=source,
144145
)
145146

147+
@staticmethod
148+
def _infer_document_type(path: Path) -> DocumentType:
149+
"""
150+
Infer the document type by checking the file signature. Use the file extension as a fallback.
151+
152+
Args:
153+
path: The path to the file.
154+
155+
Returns:
156+
The inferred document type.
157+
"""
158+
if kind := filetype.guess(path):
159+
return DocumentType(kind.extension)
160+
return DocumentType(path.suffix[1:])
161+
146162

147163
class Document(BaseModel):
148164
"""

uv.lock

Lines changed: 4 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)