diff --git a/TransformersSharp.Tests/SentenceTransformer.Test.cs b/TransformersSharp.Tests/SentenceTransformer.Test.cs index 8b7d4d2..1556492 100644 --- a/TransformersSharp.Tests/SentenceTransformer.Test.cs +++ b/TransformersSharp.Tests/SentenceTransformer.Test.cs @@ -7,7 +7,7 @@ public class SentenceTransformerTests [Fact] async public Task SentenceTransformer_ShouldGenerateEmbeddings() { - var transformer = SentenceTransformer.FromModel("nomic-ai/nomic-embed-text-v1.5", trustRemoteCode: true); + var transformer = SentenceTransformer.FromModel("sentence-transformers/clip-ViT-B-32-multilingual-v1", trustRemoteCode: true); Assert.NotNull(transformer); Assert.IsType(transformer); @@ -25,20 +25,33 @@ async public Task SentenceTransformer_ShouldGenerateEmbeddings() { Assert.NotNull(embedding); Assert.IsType>(embedding); - Assert.Equal(768, embedding.Vector.Length); // Assuming the model produces 768-dimensional embeddings + Assert.Equal(512, embedding.Vector.Length); // Assuming the model produces 512-dimensional embeddings }); } [Fact] - public void SentenceTransformer_ShouldGenerateSingleEmbedding() + public void SentenceTransformer_ShouldGenerateSingleEmbedding_Sentence() { - var transformer = SentenceTransformer.FromModel("nomic-ai/nomic-embed-text-v1.5", trustRemoteCode: true); + var transformer = SentenceTransformer.FromModel("sentence-transformers/clip-ViT-B-32-multilingual-v1", trustRemoteCode: true); Assert.NotNull(transformer); Assert.IsType(transformer); var sentence = "The quick brown fox jumps over the lazy dog."; - var embedding = transformer.Generate(sentence); + var embedding = transformer.GenerateSentence(sentence); Assert.NotNull(embedding); Assert.IsType(embedding); - Assert.Equal(768, embedding.Length); // Assuming the model produces 768-dimensional embeddings + Assert.Equal(512, embedding.Length); // Assuming the model produces 512-dimensional embeddings + } + + [Fact] + public void SentenceTransformer_ShouldGenerateSingleEmbedding_Image() + { + var transformer = SentenceTransformer.FromModel("clip-ViT-B-32", trustRemoteCode: true); + Assert.NotNull(transformer); + Assert.IsType(transformer); + var image_path = "https://images.unsplash.com/photo-1547494912-c69d3ad40e7f?ixlib=rb-4.1.0&q=85&fm=jpg&crop=entropy&cs=srgb&w=640"; + var embedding = transformer.GenerateImage(image_path); + Assert.NotNull(embedding); + Assert.IsType(embedding); + Assert.Equal(512, embedding.Length); // Assuming the model produces 512-dimensional embeddings } } diff --git a/TransformersSharp/SentenceTransformer.cs b/TransformersSharp/SentenceTransformer.cs index 176fb24..7923927 100644 --- a/TransformersSharp/SentenceTransformer.cs +++ b/TransformersSharp/SentenceTransformer.cs @@ -21,12 +21,18 @@ public void Dispose() transformerObject.Dispose(); } - public float[] Generate(string sentence) + public float[] GenerateSentence(string sentence) { var result = TransformerEnvironment.SentenceTransformersWrapper.EncodeSentence(transformerObject, sentence); return result.AsFloatReadOnlySpan().ToArray(); } + public float[] GenerateImage(string image_path) + { + var result = TransformerEnvironment.SentenceTransformersWrapper.EncodeImage(transformerObject, image_path); + return result.AsFloatReadOnlySpan().ToArray(); + } + public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) { return Task.Run(() => diff --git a/TransformersSharp/python/sentence_transformers_wrapper.py b/TransformersSharp/python/sentence_transformers_wrapper.py index e81540e..5a52472 100644 --- a/TransformersSharp/python/sentence_transformers_wrapper.py +++ b/TransformersSharp/python/sentence_transformers_wrapper.py @@ -1,4 +1,6 @@ from sentence_transformers import SentenceTransformer +from PIL import Image +import requests from collections.abc import Buffer from typing import Optional @@ -13,6 +15,11 @@ def sentence_transformer(model: str, """ return SentenceTransformer(model, device=device, cache_folder=cache_dir, revision=revision, trust_remote_code=trust_remote_code) +def load_image(url_or_path): + if url_or_path.startswith("http://") or url_or_path.startswith("https://"): + return Image.open(requests.get(url_or_path, stream=True).raw) + else: + return Image.open(url_or_path) def encode_sentence(model: SentenceTransformer, sentence: str) -> Buffer: """ @@ -20,9 +27,16 @@ def encode_sentence(model: SentenceTransformer, sentence: str) -> Buffer: """ return model.encode([sentence])[0] - def encode_sentences(model: SentenceTransformer, sentences: list[str]) -> Buffer: """ Encode a list of sentences using the SentenceTransformer model. """ return model.encode(sentences) + +def encode_image(model: SentenceTransformer, image_path: str) -> Buffer: + """ + Encode a images using the SentenceTransformer model. + """ + image = load_image(image_path) + + return model.encode(image) \ No newline at end of file