Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions TransformersSharp.Tests/SentenceTransformer.Test.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ public class SentenceTransformerTests
[Fact]
async public Task SentenceTransformer_ShouldGenerateEmbeddings()
{
var transformer = SentenceTransformer.FromModel("nomic-ai/nomic-embed-text-v1.5", trustRemoteCode: true);
var transformer = SentenceTransformer.FromModel("sentence-transformers/clip-ViT-B-32-multilingual-v1", trustRemoteCode: true);
Assert.NotNull(transformer);
Assert.IsType<SentenceTransformer>(transformer);

Expand All @@ -25,20 +25,33 @@ async public Task SentenceTransformer_ShouldGenerateEmbeddings()
{
Assert.NotNull(embedding);
Assert.IsType<Embedding<float>>(embedding);
Assert.Equal(768, embedding.Vector.Length); // Assuming the model produces 768-dimensional embeddings
Assert.Equal(512, embedding.Vector.Length); // Assuming the model produces 512-dimensional embeddings
});
}

/// <summary>
/// Verifies that <c>GenerateSentence</c> returns a single 512-element
/// <c>float[]</c> embedding for one input sentence.
/// </summary>
[Fact]
public void SentenceTransformer_ShouldGenerateSingleEmbedding_Sentence()
{
    // SentenceTransformer exposes Dispose(), which releases the wrapped Python object;
    // a using declaration makes sure that happens even when an assertion fails.
    using var transformer = SentenceTransformer.FromModel("sentence-transformers/clip-ViT-B-32-multilingual-v1", trustRemoteCode: true);
    Assert.NotNull(transformer);
    Assert.IsType<SentenceTransformer>(transformer);

    var sentence = "The quick brown fox jumps over the lazy dog.";
    var embedding = transformer.GenerateSentence(sentence);

    Assert.NotNull(embedding);
    Assert.IsType<float[]>(embedding);
    Assert.Equal(512, embedding.Length); // Assuming the model produces 512-dimensional embeddings
}

/// <summary>
/// Verifies that <c>GenerateImage</c> returns a single 512-element
/// <c>float[]</c> embedding for one image referenced by URL.
/// </summary>
/// <remarks>
/// NOTE(review): this test downloads both the model and the image over the network,
/// so it will fail when offline or if the remote image is removed.
/// </remarks>
[Fact]
public void SentenceTransformer_ShouldGenerateSingleEmbedding_Image()
{
    // SentenceTransformer exposes Dispose(), which releases the wrapped Python object;
    // a using declaration makes sure that happens even when an assertion fails.
    using var transformer = SentenceTransformer.FromModel("clip-ViT-B-32", trustRemoteCode: true);
    Assert.NotNull(transformer);
    Assert.IsType<SentenceTransformer>(transformer);

    var imagePath = "https://images.unsplash.com/photo-1547494912-c69d3ad40e7f?ixlib=rb-4.1.0&q=85&fm=jpg&crop=entropy&cs=srgb&w=640";
    var embedding = transformer.GenerateImage(imagePath);

    Assert.NotNull(embedding);
    Assert.IsType<float[]>(embedding);
    Assert.Equal(512, embedding.Length); // Assuming the model produces 512-dimensional embeddings
}
}
8 changes: 7 additions & 1 deletion TransformersSharp/SentenceTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,18 @@ public void Dispose()
transformerObject.Dispose();
}

/// <summary>
/// Encodes a single sentence into an embedding vector by calling the Python
/// wrapper's <c>EncodeSentence</c> with the wrapped transformer object.
/// </summary>
/// <param name="sentence">The sentence to encode.</param>
/// <returns>A newly allocated array copied from the buffer returned by the wrapper.</returns>
public float[] GenerateSentence(string sentence)
{
    var result = TransformerEnvironment.SentenceTransformersWrapper.EncodeSentence(transformerObject, sentence);
    // Copy out of the Python-owned buffer so the caller gets an independent managed array.
    return result.AsFloatReadOnlySpan().ToArray();
}

/// <summary>
/// Backward-compatible alias for <see cref="GenerateSentence(string)"/>, kept so that
/// callers written against the previous method name continue to compile.
/// </summary>
/// <param name="sentence">The sentence to encode.</param>
/// <returns>The same result as <see cref="GenerateSentence(string)"/>.</returns>
[System.Obsolete("Use GenerateSentence instead.")]
public float[] Generate(string sentence) => GenerateSentence(sentence);

/// <summary>
/// Encodes a single image into an embedding vector by calling the Python
/// wrapper's <c>EncodeImage</c> with the wrapped transformer object.
/// </summary>
/// <param name="image_path">
/// Image location, forwarded to the Python wrapper unchanged. The wrapper treats
/// values starting with <c>http://</c> or <c>https://</c> as URLs and anything
/// else as a local file path.
/// </param>
/// <returns>A newly allocated array copied from the buffer returned by the wrapper.</returns>
public float[] GenerateImage(string image_path)
{
    var encoded = TransformerEnvironment.SentenceTransformersWrapper.EncodeImage(transformerObject, image_path);
    var values = encoded.AsFloatReadOnlySpan();
    return values.ToArray();
}

public Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
{
return Task.Run(() =>
Expand Down
16 changes: 15 additions & 1 deletion TransformersSharp/python/sentence_transformers_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from sentence_transformers import SentenceTransformer
from PIL import Image
import requests
from collections.abc import Buffer
from typing import Optional

Expand All @@ -13,16 +15,28 @@ def sentence_transformer(model: str,
"""
return SentenceTransformer(model, device=device, cache_folder=cache_dir, revision=revision, trust_remote_code=trust_remote_code)

def load_image(url_or_path, timeout=30.0):
    """
    Open an image from an HTTP(S) URL or from a local file path.

    Args:
        url_or_path: A string. Values starting with ``http://`` or ``https://``
            are downloaded with ``requests``; anything else is passed to
            ``Image.open`` as a file path.
        timeout: Seconds to wait for the server before giving up. Only used
            for URLs.

    Returns:
        The object returned by ``PIL.Image.open``.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        requests.RequestException: Any other download failure, including a timeout.
    """
    if url_or_path.startswith(("http://", "https://")):
        # Function-scope import so this block does not depend on a new file-level import.
        from io import BytesIO

        # Without a timeout a stalled server would block the caller indefinitely.
        response = requests.get(url_or_path, timeout=timeout)
        # Fail here with a clear HTTP error rather than handing an error page to PIL.
        response.raise_for_status()
        # response.content is fully read and content-decoded, and the connection is
        # released; response.raw is neither decoded nor closed for us.
        return Image.open(BytesIO(response.content))
    return Image.open(url_or_path)

def encode_sentence(model: SentenceTransformer, sentence: str) -> Buffer:
    """
    Encode a single sentence using the SentenceTransformer model.

    The sentence is wrapped in a one-element list for ``model.encode`` and the
    first (only) entry of the result is returned.
    """
    batch = [sentence]
    embeddings = model.encode(batch)
    return embeddings[0]


def encode_sentences(model: SentenceTransformer, sentences: list[str]) -> Buffer:
    """
    Encode a batch of sentences using the SentenceTransformer model.

    The list is handed to ``model.encode`` as-is and its result is returned unchanged.
    """
    embeddings = model.encode(sentences)
    return embeddings

def encode_image(model: SentenceTransformer, image_path: str) -> Buffer:
    """
    Encode a single image using the SentenceTransformer model.

    Args:
        model: The model whose ``encode`` method is called with the opened image.
        image_path: URL or local file path, resolved by ``load_image``.

    Returns:
        Whatever ``model.encode`` returns for the single image.
    """
    # The image object can keep its underlying file open until it is closed, so use it
    # as a context manager to release that handle once encoding has finished.
    with load_image(image_path) as image:
        return model.encode(image)
Loading