diff --git a/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Protocol/Result/ResultNext.cs b/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Protocol/Result/ResultNext.cs index d248a1feb..63952f54e 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Protocol/Result/ResultNext.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Protocol/Result/ResultNext.cs @@ -38,6 +38,7 @@ public override async Task Process() } catch (TimeZoneNotFoundException tz) { + var message = tz.Message + " This can happen if the server is set to a timezone that is not recognized by .NET"; throw new DriverExceptionWrapper(tz); } } diff --git a/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Protocol/Session/JsonCypherParameterParser.cs b/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Protocol/Session/JsonCypherParameterParser.cs index 82eaaa919..890c3fb8c 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Protocol/Session/JsonCypherParameterParser.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Protocol/Session/JsonCypherParameterParser.cs @@ -68,6 +68,15 @@ public static CypherToNativeObject ExtractParameterFromProperty(JObject paramete }; } + if (parameter["name"].Value() == "CypherVector") + { + return new CypherToNativeObject() + { + name = parameter["name"].Value(), + data = parameter["data"].ToObject() + }; + } + return new CypherToNativeObject { name = parameter["name"].Value(), diff --git a/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Protocol/Session/SessionRun.cs b/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Protocol/Session/SessionRun.cs index 058800f3f..ad47b0a79 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Protocol/Session/SessionRun.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Protocol/Session/SessionRun.cs @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +using System; using System.Collections.Generic; using System.Threading.Tasks; using Neo4j.Driver.Tests.TestBackend.Protocol.JsonConverters; @@ -38,15 +39,26 @@ public override async Task Process() data.TransactionConfig) .ConfigureAwait(false); - var result = ProtocolObjectFactory.CreateObject(); - result.ResultCursor = cursor; + Result.Result result = null; + try + { + result = ProtocolObjectFactory.CreateObject(); + result.ResultCursor = cursor; + } + catch(Exception ex) + { + throw; + } ResultId = result.uniqueId; } public override string Respond() { - return ((Result.Result)ObjManager.GetObject(ResultId)).Respond(); + var protocolObject = (Result.Result)ObjManager.GetObject(ResultId); + var type = protocolObject.GetType(); + var response = protocolObject.Respond(); + return response; } [JsonConverter(typeof(SessionTypeJsonConverter))] diff --git a/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/SupportedFeatures.cs b/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/SupportedFeatures.cs index 30ac2c27e..eff1bbae4 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/SupportedFeatures.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/SupportedFeatures.cs @@ -50,6 +50,7 @@ static SupportedFeatures() "Feature:API:SSLSchemes", "Feature:API:Summary:GqlStatusObjects", "Feature:API:Type.Temporal", + "Feature:API:Type.Vector", "Feature:Auth:Bearer", "Feature:Auth:Custom", "Feature:Auth:Kerberos", @@ -68,6 +69,7 @@ static SupportedFeatures() "Feature:Bolt:5.6", "Feature:Bolt:5.7", "Feature:Bolt:5.8", + "Feature:Bolt:6.0", "Feature:Bolt:Patch:UTC", "Feature:Bolt:HandshakeManifestV1", "Feature:Impersonation", diff --git a/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Types/CypherToNative.cs b/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Types/CypherToNative.cs index 6118188a8..66aec78a9 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Types/CypherToNative.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Types/CypherToNative.cs @@ -17,6 +17,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using Neo4j.Driver.Internal.Util; using Neo4j.Driver.Tests.TestBackend.Protocol.Session; using Newtonsoft.Json.Linq; @@ -66,6 +67,12 @@ public class DateTimeParameterValue public string timezone_id { get; set; } } +public class VectorParameterValue +{ + public string? dtype { get; set; } + public string? data { get; set; } +} + public class DurationParameterValue { public long? months { get; set; } @@ -95,6 +102,7 @@ internal class CypherToNative { "CypherLocalDateTime", typeof(LocalDateTime) }, { "CypherDuration", typeof(Duration) }, { "CypherPoint", typeof(Point) }, + { "CypherVector", typeof(Vector) }, { "CypherNode", typeof(INode) }, { "CypherRelationship", typeof(IRelationship) }, @@ -120,6 +128,7 @@ internal class CypherToNative { typeof(LocalDateTime), CypherDateTime }, { typeof(Duration), CypherDuration }, { typeof(Point), CypherTODO }, + { typeof(Vector), CypherVector }, { typeof(INode), CypherTODO }, { typeof(IRelationship), CypherTODO }, @@ -143,7 +152,7 @@ public static object Convert(CypherToNativeObject sourceObject) catch { throw new IOException( - $"Attempting to convert an unsuported object type to a CypherType: {sourceObject.GetType()}"); + $"Attempting to convert an unsupported object type to a CypherType: {sourceObject.GetType()}"); } } @@ -241,6 +250,43 @@ private static object CypherDateTime(Type objectType, CypherToNativeObject obj) dataTimeParam.nanosecond.Value); } + private static object CypherVector(Type objectType, CypherToNativeObject obj) + { + var data = (VectorParameterValue)obj.data; + var byteArray = ConvertStringToBytes(data.data); + var vectorType = SupportedTypeNames[data.dtype!]; + var typedArray = BytesToTypedArrayHelper.ConvertBytesToTypedArray(byteArray, vectorType); + return Vector.CreateDynamic(typedArray, byteArray); + } + + private static byte[] ConvertStringToBytes(string hexString) + { + if (string.IsNullOrEmpty(hexString)) + { + return Array.Empty(); + } + + // string looks like "7f c0 23 3a" etc. + var hexValues = hexString.Split(' ', StringSplitOptions.RemoveEmptyEntries); + var byteArray = new byte[hexValues.Length]; + for (var i = 0; i < hexValues.Length; i++) + { + byteArray[i] = System.Convert.ToByte(hexValues[i], 16); + } + + return byteArray; + } + + internal static readonly Dictionary SupportedTypeNames = new() + { + ["i8"] = typeof(sbyte), + ["i16"] = typeof(short), + ["i32"] = typeof(int), + ["i64"] = typeof(long), + ["f32"] = typeof(float), + ["f64"] = typeof(double) + }; + private static object CypherDuration(Type objectType, CypherToNativeObject obj) { var duration = obj.data as DurationParameterValue; diff --git a/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Types/NativeToCypher.cs b/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Types/NativeToCypher.cs index 1caaeaf97..20a4266c9 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Types/NativeToCypher.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests.TestBackend/Types/NativeToCypher.cs @@ -17,6 +17,8 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using Neo4j.Driver.Internal.IO.ValueSerializers.VectorSerializers; +using Neo4j.Driver.Tests.TestBackend.Protocol.Result; #pragma warning disable CS0618 // Type or member is obsolete - but still needs to be handled @@ -40,6 +42,7 @@ internal static class NativeToCypher { { typeof(List), CypherList }, { typeof(Dictionary), CypherMap }, + { typeof(IVector), CypherVector }, { typeof(bool), CypherSimple }, { typeof(long), CypherSimple }, @@ -67,12 +70,17 @@ public static object Convert(object sourceObject) return new NativeToCypherObject { name = "CypherNull" }; } - if (sourceObject as List != null) + if(sourceObject is IVector) + { + return FunctionMap[typeof(IVector)]("CypherVector", sourceObject); + } + + if (sourceObject is List) { return FunctionMap[typeof(List)]("CypherList", sourceObject); } - if (sourceObject as Dictionary != null) + if (sourceObject is Dictionary) { return FunctionMap[typeof(Dictionary)]("CypherMap", sourceObject); } @@ -187,6 +195,37 @@ public static NativeToCypherObject CypherList(string cypherType, object obj) { name = cypherType, data = new NativeToCypherObject.DataType { value = result } }; } + internal static readonly Dictionary VectorTypeMap = new() + { + [typeof(sbyte)] = "i8", + [typeof(short)] = "i16", + [typeof(int)] = "i32", + [typeof(long)] = "i64", + [typeof(float)] = "f32", + [typeof(double)] = "f64" + }; + + public static NativeToCypherObject CypherVector(string cypherType, object obj) + { + var vector = (IVector)obj; + Dictionary result = new() + { + ["dtype"] = VectorTypeMap[vector.ElementType], + ["data"] = ByteStreamToHexString(VectorSerializer.GetByteStream(vector)) + }; + + return new NativeToCypherObject() + { + data = result, + name = cypherType + }; + } + + private static string ByteStreamToHexString(byte[] byteStream) + { + return string.Join(" ", byteStream.Select(b => b.ToString("x2"))); + } + public static NativeToCypherObject CypherTODO(string name, object obj) { throw new NotImplementedException($"NativeToCypher : {name} conversion is not implemented yet"); diff --git a/Neo4j.Driver/Neo4j.Driver.Tests/Internal/IO/ValueSerializers/VectorSerializerTests.cs b/Neo4j.Driver/Neo4j.Driver.Tests/Internal/IO/ValueSerializers/VectorSerializerTests.cs new file mode 100644 index 000000000..ed90bc268 --- /dev/null +++ b/Neo4j.Driver/Neo4j.Driver.Tests/Internal/IO/ValueSerializers/VectorSerializerTests.cs @@ -0,0 +1,285 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Linq; +using FluentAssertions; +using Neo4j.Driver.Internal.IO; +using Neo4j.Driver.Internal.IO.ValueSerializers.VectorSerializers; +using Xunit; + +namespace Neo4j.Driver.Tests.Internal.IO.ValueSerializers; + +public class VectorSerializerTests : PackStreamSerializerTests +{ + internal override IPackStreamSerializer SerializerUnderTest => new VectorSerializer(); + + [Fact] + public void ShouldSerializeFloat32Vector() + { + var writerMachine = CreateWriterMachine(); + var writer = writerMachine.Writer; + + var float32Vector = new[] { 0.1f, 0.2f, 0.3f }; + var vector = Vector.Create(float32Vector); + + writer.Write(vector); + + var readerMachine = CreateReaderMachine(writerMachine.GetOutput()); + var reader = readerMachine.Reader(); + + reader.PeekNextType().Should().Be(PackStreamType.Struct); + reader.ReadStructHeader().Should().Be(2); // Size of the struct + reader.ReadStructSignature().Should().Be((byte)'V'); // Vector struct type + reader.ReadBytes().Should().BeEquivalentTo([PackStream.Float32]); + reader.ReadBytes().Should().BeEquivalentTo(float32Vector.SelectMany(PackStreamBitConverter.GetBytes).ToArray()); + } + + [Fact] + public void ShouldSerializeFloat64Vector() + { + var writerMachine = CreateWriterMachine(); + var writer = writerMachine.Writer; + + var float64Vector = new[] { 0.1, 0.2 }; + var vector = Vector.Create(float64Vector); + + writer.Write(vector); + + var readerMachine = CreateReaderMachine(writerMachine.GetOutput()); + var reader = readerMachine.Reader(); + + reader.PeekNextType().Should().Be(PackStreamType.Struct); + reader.ReadStructHeader().Should().Be(2); // Size of the struct + reader.ReadStructSignature().Should().Be((byte)'V'); // Vector struct type + reader.ReadBytes().Should().BeEquivalentTo([PackStream.Float64]); + reader.ReadBytes().Should().BeEquivalentTo(float64Vector.SelectMany(PackStreamBitConverter.GetBytes).ToArray()); + } + + [Fact] + public void ShouldSerializeByteVector() + { + var writerMachine = CreateWriterMachine(); + var writer = writerMachine.Writer; + + var byteVector = new sbyte[] { 1, 2, 3 }; + var vector = Vector.Create(byteVector); + + writer.Write(vector); + + var readerMachine = CreateReaderMachine(writerMachine.GetOutput()); + var reader = readerMachine.Reader(); + + reader.PeekNextType().Should().Be(PackStreamType.Struct); + reader.ReadStructHeader().Should().Be(2); // Size of the struct + reader.ReadStructSignature().Should().Be((byte)'V'); // Vector struct type + reader.ReadBytes().Should().BeEquivalentTo([PackStream.Int8]); + reader.Read().Should().BeEquivalentTo(byteVector); + } + + [Fact] + public void ShouldSerializeInt16Vector() + { + var writerMachine = CreateWriterMachine(); + var writer = writerMachine.Writer; + + var int16Vector = new short[] { 100, 200, 300 }; + var vector = Vector.Create(int16Vector); + + writer.Write(vector); + + var readerMachine = CreateReaderMachine(writerMachine.GetOutput()); + var reader = readerMachine.Reader(); + + reader.PeekNextType().Should().Be(PackStreamType.Struct); + reader.ReadStructHeader().Should().Be(2); // Size of the struct + reader.ReadStructSignature().Should().Be((byte)'V'); // Vector struct type + reader.ReadBytes().Should().BeEquivalentTo([PackStream.Int16]); + reader.ReadBytes().Should().BeEquivalentTo(int16Vector.SelectMany(PackStreamBitConverter.GetBytes).ToArray()); + } + + [Fact] + public void ShouldSerializeInt32Vector() + { + var writerMachine = CreateWriterMachine(); + var writer = writerMachine.Writer; + + var int32Vector = new[] { 1, 2, 3 }; + var vector = Vector.Create(int32Vector); + + writer.Write(vector); + + var readerMachine = CreateReaderMachine(writerMachine.GetOutput()); + var reader = readerMachine.Reader(); + + reader.PeekNextType().Should().Be(PackStreamType.Struct); + reader.ReadStructHeader().Should().Be(2); // Size of the struct + reader.ReadStructSignature().Should().Be((byte)'V'); // Vector struct type + reader.ReadBytes().Should().BeEquivalentTo([PackStream.Int32]); + reader.ReadBytes().Should().BeEquivalentTo(int32Vector.SelectMany(PackStreamBitConverter.GetBytes).ToArray()); + } + + [Fact] + public void ShouldSerializeInt64Vector() + { + var writerMachine = CreateWriterMachine(); + var writer = writerMachine.Writer; + + var int64Vector = new[] { 1000L, 2000L, 3000L }; + var vector = Vector.Create(int64Vector); + + writer.Write(vector); + + var readerMachine = CreateReaderMachine(writerMachine.GetOutput()); + var reader = readerMachine.Reader(); + + reader.PeekNextType().Should().Be(PackStreamType.Struct); + reader.ReadStructHeader().Should().Be(2); // Size of the struct + reader.ReadStructSignature().Should().Be((byte)'V'); // Vector struct type + reader.ReadBytes().Should().BeEquivalentTo([PackStream.Int64]); + reader.ReadBytes().Should().BeEquivalentTo(int64Vector.SelectMany(PackStreamBitConverter.GetBytes).ToArray()); + } + + [Fact] + public void ShouldDeserializeSByteVector() + { + var writerMachine = CreateWriterMachine(); + var writer = writerMachine.Writer; + + var byteVector = new sbyte[] { 1, 2, 3 }; + + writer.WriteStructHeader(2, (byte)'V'); + writer.WriteByteArray([PackStream.Int8]); + writer.WriteByteArray(byteVector.Select(b => (byte)b).ToArray()); + + var readerMachine = CreateReaderMachine(writerMachine.GetOutput()); + var reader = readerMachine.Reader(); + var value = reader.Read(); + + value.Should().BeOfType>(); + var vector = (Vector)value; + vector.ElementType.Should().Be(typeof(sbyte)); + vector.Values.Should().BeEquivalentTo(byteVector); + } + + [Fact] + public void ShouldDeserializeInt16Vector() + { + var writerMachine = CreateWriterMachine(); + var writer = writerMachine.Writer; + + var int16Vector = new short[] { 100, 200, 300 }; + + writer.WriteStructHeader(2, (byte)'V'); + writer.WriteByteArray([PackStream.Int16]); + writer.WriteByteArray(int16Vector.SelectMany(PackStreamBitConverter.GetBytes).ToArray()); + + var readerMachine = CreateReaderMachine(writerMachine.GetOutput()); + var reader = readerMachine.Reader(); + var value = reader.Read(); + + value.Should().BeOfType>(); + var vector = (Vector)value; + vector.ElementType.Should().Be(typeof(short)); + vector.Values.Should().BeEquivalentTo(int16Vector); + } + + [Fact] + public void ShouldDeserializeInt32Vector() + { + var writerMachine = CreateWriterMachine(); + var writer = writerMachine.Writer; + + var int32Vector = new[] { 1, 2, 3 }; + + writer.WriteStructHeader(2, (byte)'V'); + writer.WriteByteArray([PackStream.Int32]); + writer.WriteByteArray(int32Vector.SelectMany(PackStreamBitConverter.GetBytes).ToArray()); + + var readerMachine = CreateReaderMachine(writerMachine.GetOutput()); + var reader = readerMachine.Reader(); + var value = reader.Read(); + + value.Should().BeOfType>(); + var vector = (Vector)value; + vector.ElementType.Should().Be(typeof(int)); + vector.Values.Should().BeEquivalentTo(int32Vector); + } + + [Fact] + public void ShouldDeserializeInt64Vector() + { + var writerMachine = CreateWriterMachine(); + var writer = writerMachine.Writer; + + var int64Vector = new[] { 1000L, 2000L, 3000L }; + + writer.WriteStructHeader(2, (byte)'V'); + writer.WriteByteArray([PackStream.Int64]); + writer.WriteByteArray(int64Vector.SelectMany(PackStreamBitConverter.GetBytes).ToArray()); + + var readerMachine = CreateReaderMachine(writerMachine.GetOutput()); + var reader = readerMachine.Reader(); + var value = reader.Read(); + + value.Should().BeOfType>(); + var vector = (Vector)value; + vector.ElementType.Should().Be(typeof(long)); + vector.Values.Should().BeEquivalentTo(int64Vector); + } + + [Fact] + public void ShouldDeserializeFloat32Vector() + { + var writerMachine = CreateWriterMachine(); + var writer = writerMachine.Writer; + + var float32Vector = new[] { 0.1f, 0.2f, 0.3f }; + + writer.WriteStructHeader(2, (byte)'V'); + writer.WriteByteArray([PackStream.Float32]); + writer.WriteByteArray(float32Vector.SelectMany(PackStreamBitConverter.GetBytes).ToArray()); + + var readerMachine = CreateReaderMachine(writerMachine.GetOutput()); + var reader = readerMachine.Reader(); + var value = reader.Read(); + + value.Should().BeOfType>(); + var vector = (Vector)value; + vector.ElementType.Should().Be(typeof(float)); + vector.Values.Should().BeEquivalentTo(float32Vector); + } + + [Fact] + public void ShouldDeserializeFloat64Vector() + { + var writerMachine = CreateWriterMachine(); + var writer = writerMachine.Writer; + + var float64Vector = new[] { 0.1, 0.2 }; + + writer.WriteStructHeader(2, (byte)'V'); + writer.WriteByteArray([PackStream.Float64]); + writer.WriteByteArray(float64Vector.SelectMany(PackStreamBitConverter.GetBytes).ToArray()); + + var readerMachine = CreateReaderMachine(writerMachine.GetOutput()); + var reader = readerMachine.Reader(); + var value = reader.Read(); + + value.Should().BeOfType>(); + var vector = (Vector)value; + vector.ElementType.Should().Be(typeof(double)); + vector.Values.Should().BeEquivalentTo(float64Vector); + } +} diff --git a/Neo4j.Driver/Neo4j.Driver.Tests/Internal/Util/BytesToTypedArrayTests.cs b/Neo4j.Driver/Neo4j.Driver.Tests/Internal/Util/BytesToTypedArrayTests.cs new file mode 100644 index 000000000..e7f0f0f99 --- /dev/null +++ b/Neo4j.Driver/Neo4j.Driver.Tests/Internal/Util/BytesToTypedArrayTests.cs @@ -0,0 +1,214 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using FluentAssertions; +using Neo4j.Driver.Internal.Util; +using Xunit; + +namespace Neo4j.Driver.Tests.Internal.Util; + +public class BytesToTypedArrayHelperTests +{ + [Fact] + public void ConvertBytesToTypedArray_SByte_ReturnsCorrectArray() + { + // Arrange + var bytes = new byte[] { 0x7F, 0x80, 0x00, 0xFF }; // max, min, zero, -1 + + // Act + var result = BytesToTypedArrayHelper.ConvertBytesToTypedArray(bytes, typeof(sbyte)); + + // Assert + result.Should().BeOfType(); + var sbyteArray = (sbyte[])result; + sbyteArray.Should().HaveCount(4); + sbyteArray.Should().Equal(127, -128, 0, -1); + } + + [Fact] + public void ConvertBytesToTypedArray_Short_ReturnsCorrectArray() + { + // Arrange - big-endian bytes for short values + var bytes = new byte[] { 0x7F, 0xFF, 0x80, 0x00, 0x00, 0x00, 0xFF, 0xFF }; // 32767, -32768, 0, -1 + + // Act + var result = BytesToTypedArrayHelper.ConvertBytesToTypedArray(bytes, typeof(short)); + + // Assert + result.Should().BeOfType(); + var shortArray = (short[])result; + shortArray.Should().HaveCount(4); + shortArray.Should().Equal(32767, -32768, 0, -1); + } + + [Fact] + public void ConvertBytesToTypedArray_Int_ReturnsCorrectArray() + { + // Arrange - big-endian bytes for int values + var bytes = new byte[] + { + 0x7F, 0xFF, 0xFF, 0xFF, // 2147483647 + 0x80, 0x00, 0x00, 0x00, // -2147483648 + 0x00, 0x00, 0x00, 0x00, // 0 + 0xFF, 0xFF, 0xFF, 0xFF // -1 + }; + + // Act + var result = BytesToTypedArrayHelper.ConvertBytesToTypedArray(bytes, typeof(int)); + + // Assert + result.Should().BeOfType(); + var intArray = (int[])result; + intArray.Should().HaveCount(4); + intArray.Should().Equal(2147483647, -2147483648, 0, -1); + } + + [Fact] + public void ConvertBytesToTypedArray_Long_ReturnsCorrectArray() + { + // Arrange - big-endian bytes for long values + var bytes = new byte[] + { + 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 9223372036854775807 + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // -9223372036854775808 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 0 + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF // -1 + }; + + // Act + var result = BytesToTypedArrayHelper.ConvertBytesToTypedArray(bytes, typeof(long)); + + // Assert + result.Should().BeOfType(); + var longArray = (long[])result; + longArray.Should().HaveCount(4); + longArray.Should().Equal(9223372036854775807L, -9223372036854775808L, 0L, -1L); + } + + [Fact] + public void ConvertBytesToTypedArray_Float_ReturnsCorrectArray() + { + // Arrange - big-endian bytes for float values (IEEE 754) + var bytes = new byte[] + { + 0x3F, 0x80, 0x00, 0x00, // 1.0f + 0xBF, 0x80, 0x00, 0x00, // -1.0f + 0x00, 0x00, 0x00, 0x00, // 0.0f + 0x42, 0x28, 0x00, 0x00 // 42.0f + }; + + // Act + var result = BytesToTypedArrayHelper.ConvertBytesToTypedArray(bytes, typeof(float)); + + // Assert + result.Should().BeOfType(); + var floatArray = (float[])result; + floatArray.Should().HaveCount(4); + floatArray.Should().Equal(1.0f, -1.0f, 0.0f, 42.0f); + } + + [Fact] + public void ConvertBytesToTypedArray_Double_ReturnsCorrectArray() + { + // Arrange - big-endian bytes for double values (IEEE 754) + var bytes = new byte[] + { + 0x3F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 1.0 + 0xBF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // -1.0 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 0.0 + 0x40, 0x45, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 // 42.0 + }; + + // Act + var result = BytesToTypedArrayHelper.ConvertBytesToTypedArray(bytes, typeof(double)); + + // Assert + result.Should().BeOfType(); + var doubleArray = (double[])result; + doubleArray.Should().HaveCount(4); + doubleArray.Should().Equal(1.0, -1.0, 0.0, 42.0); + } + + [Fact] + public void ConvertBytesToTypedArray_EmptyArray_ReturnsEmptyTypedArray() + { + // Arrange + var bytes = Array.Empty(); + + // Act + var result = BytesToTypedArrayHelper.ConvertBytesToTypedArray(bytes, typeof(int)); + + // Assert + result.Should().BeOfType(); + var intArray = (int[])result; + intArray.Should().BeEmpty(); + } + + [Fact] + public void ConvertBytesToTypedArray_SingleElement_ReturnsCorrectArray() + { + // Arrange + var bytes = new byte[] { 0x00, 0x00, 0x00, 0x2A }; // 42 in big-endian + + // Act + var result = BytesToTypedArrayHelper.ConvertBytesToTypedArray(bytes, typeof(int)); + + // Assert + result.Should().BeOfType(); + var intArray = (int[])result; + intArray.Should().ContainSingle().Which.Should().Be(42); + } + + [Theory] + [InlineData(typeof(sbyte))] + [InlineData(typeof(short))] + [InlineData(typeof(int))] + [InlineData(typeof(long))] + [InlineData(typeof(float))] + [InlineData(typeof(double))] + public void ConvertBytesToTypedArray_ValidTypes_ReturnsCorrectType(Type elementType) + { + // Arrange + var elementSize = System.Runtime.InteropServices.Marshal.SizeOf(elementType); + var bytes = new byte[elementSize * 2]; // 2 elements + + // Act + var result = BytesToTypedArrayHelper.ConvertBytesToTypedArray(bytes, elementType); + + // Assert + result.Should().BeOfType(elementType.MakeArrayType()); + var array = (Array)result; + array.Should().HaveCount(2); + } + + [Fact] + public void ConvertBytesToTypedArray_ModifiesOriginalByteArray() + { + // Arrange + var originalBytes = new byte[] { 0x00, 0x00, 0x00, 0x01 }; // 1 in big-endian + var bytesCopy = (byte[])originalBytes.Clone(); + + // Act + BytesToTypedArrayHelper.ConvertBytesToTypedArray(bytesCopy, typeof(int)); + + // Assert - on little-endian systems, bytes should be reversed + if (BitConverter.IsLittleEndian) + { + bytesCopy.Should().NotEqual(originalBytes); + bytesCopy.Should().Equal(new byte[] { 0x01, 0x00, 0x00, 0x00 }); + } + } +} diff --git a/Neo4j.Driver/Neo4j.Driver.Tests/Mapping/MappingProviderTests.cs b/Neo4j.Driver/Neo4j.Driver.Tests/Mapping/MappingProviderTests.cs index 911cbd030..9799e17c9 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests/Mapping/MappingProviderTests.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests/Mapping/MappingProviderTests.cs @@ -102,17 +102,6 @@ public void ShouldNotFailWhenUsingDefaultMapperButMappingSomePropertiesExplicitl obj.Guid.Should().Be(guid); } - [Fact] - public void ShouldFailWhenUsingDefaultMapperWithoutOverriding() - { - var guid = Guid.NewGuid(); - var testRecord = TestRecord.Create(("Name", "Alice"), ("Guid", guid.ToString())); - RecordObjectMapping.RegisterProvider(new MappingProviderThatUsesDefaultMappingAndOverridesAGuidProperty(false)); - - var act = () => testRecord.AsObject(); - act.Should().Throw(); - } - private class TestObject { [MappingSource("intValue")] diff --git a/Neo4j.Driver/Neo4j.Driver.Tests/Mapping/RecordMappingTests.cs b/Neo4j.Driver/Neo4j.Driver.Tests/Mapping/RecordMappingTests.cs index 244b0c39a..5d6c340e9 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests/Mapping/RecordMappingTests.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests/Mapping/RecordMappingTests.cs @@ -575,7 +575,7 @@ public async Task MapMethods_ShouldBeThreadSafe() await Task.WhenAll(tasks); // Fail the test if any exceptions were caught - if (exceptions.Count > 0) + if (!exceptions.IsEmpty) { throw new AggregateException("Thread safety issues detected.", exceptions); } diff --git a/Neo4j.Driver/Neo4j.Driver.Tests/Mapping/VectorMappingTests.cs b/Neo4j.Driver/Neo4j.Driver.Tests/Mapping/VectorMappingTests.cs new file mode 100644 index 000000000..0a520f6be --- /dev/null +++ b/Neo4j.Driver/Neo4j.Driver.Tests/Mapping/VectorMappingTests.cs @@ -0,0 +1,262 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; +using FluentAssertions; +using Neo4j.Driver.Internal.Types; +using Neo4j.Driver.Mapping; +using Neo4j.Driver.Tests.TestUtil; +using Xunit; + +namespace Neo4j.Driver.Tests.Mapping; + +// ReSharper disable once ClassNeverInstantiated.Global +public class ClassWithVector +{ + public int Id { get; set; } + public string Name { get; set; } = null!; + public Vector DoubleVector { get; set; } = null!; +} + +public class VectorMappingTests +{ + [Fact] + public void Should_Map_Vector_Property() + { + var record = TestRecord.Create( + ("Id", 1), + ("Name", "Alice"), + ("DoubleVector", new Vector([1.0, 2.0, 3.0]))); + + var poco = record.AsObject(); + + poco.Id.Should().Be(1); + poco.Name.Should().Be("Alice"); + poco.DoubleVector.Values.Should().Equal(1.0, 2.0, 3.0); + } + + [Fact] + public void Should_Throw_When_Vector_Element_Type_Does_Not_Match_Property_Type() + { + var record = TestRecord.Create( + ("Id", 1), + ("Name", "Alice"), + ("DoubleVector", new Vector([1, 2, 3]))); + + var act = () => record.AsObject(); + + act.Should().Throw(); + } + + [Fact] + public void Should_Map_Vector_Nested_In_Anon_Object() + { + var node = new Node( + 1, + ["Label"], + new Dictionary + { + { "Id", 1 }, + { "Name", "Alice" }, + { "DoubleVector", new Vector([1.0, 2.0, 3.0]) } + }); + + var record = TestRecord.Create(("Nested", node)); + var poco = record.AsObjectFromBlueprint( + new + { + Nested = new + { + Id = 0, Name = string.Empty, DoubleVector = (Vector)null! + } + }); + + poco.Nested.Id.Should().Be(1); + poco.Nested.Name.Should().Be("Alice"); + poco.Nested.DoubleVector.Values.Should().Equal(1.0, 2.0, 3.0); + } + + private IRecord GetTestRecord() + { + return TestRecord.Create( + ("Ints", new Vector([5, 6, 7])), + ("Longs", new Vector([3456L, 4567L, 5678L])), + ("Floats", new Vector([394857.5f, 48576.4f, 5768.3f])), + ("Doubles", new Vector([854765.45, 94765.34, 8765.23])), + ("Shorts", new Vector([43, 54, 65])), + ("Bytes", new Vector([94, 85, 76]))); + } + + [Fact] + public void Should_Convert_Vector_To_Array() + { + var record = GetTestRecord(); + var obj = record.AsObjectFromBlueprint( + new + { + Ints = System.Array.Empty(), + Longs = System.Array.Empty(), + Floats = System.Array.Empty(), + Doubles = System.Array.Empty(), + Shorts = System.Array.Empty(), + Bytes = System.Array.Empty() + }); + + obj.Ints.Should().Equal(5, 6, 7); + obj.Longs.Should().Equal(3456L, 4567L, 5678L); + obj.Floats.Should().BeApproximately([394857.5f, 48576.4f, 5768.3f]); + obj.Doubles.Should().BeApproximately([854765.45, 94765.34, 8765.23]); + obj.Shorts.Should().Equal(43, 54, 65); + obj.Bytes.Should().Equal(94, 85, 76); + } + + [Fact] + public void Should_Convert_Vector_To_List() + { + var record = GetTestRecord(); + var obj = record.AsObjectFromBlueprint( + new + { + Ints = new List(), + Longs = new List(), + Floats = new List(), + Doubles = new List(), + Shorts = new List(), + Bytes = new List() + }); + + obj.Ints.Should().Equal(5, 6, 7); + obj.Longs.Should().Equal(3456L, 4567L, 5678L); + obj.Floats.Should().BeApproximately([394857.5f, 48576.4f, 5768.3f]); + obj.Doubles.Should().BeApproximately([854765.45, 94765.34, 8765.23]); + obj.Shorts.Should().Equal(43, 54, 65); + obj.Bytes.Should().Equal(94, 85, 76); + } + + [Fact] + public void Should_Convert_Vector_To_IList() + { + var record = GetTestRecord(); + var obj = record.AsObjectFromBlueprint( + new + { + Ints = (IList)null!, + Longs = (IList)null!, + Floats = (IList)null!, + Doubles = (IList)null!, + Shorts = (IList)null!, + Bytes = (IList)null! + }); + + obj.Ints.Should().Equal(5, 6, 7); + obj.Longs.Should().Equal(3456L, 4567L, 5678L); + obj.Floats.Should().BeApproximately([394857.5f, 48576.4f, 5768.3f]); + obj.Doubles.Should().BeApproximately([854765.45, 94765.34, 8765.23]); + obj.Shorts.Should().Equal(43, 54, 65); + obj.Bytes.Should().Equal(94, 85, 76); + } + + [Fact] + public void Should_Convert_Vector_To_IEnumerable() + { + var record = GetTestRecord(); + var obj = record.AsObjectFromBlueprint( + new + { + Ints = (IEnumerable)null!, + Longs = (IEnumerable)null!, + Floats = (IEnumerable)null!, + Doubles = (IEnumerable)null!, + Shorts = (IEnumerable)null!, + Bytes = (IEnumerable)null! + }); + + obj.Ints.Should().Equal(5, 6, 7); + obj.Longs.Should().Equal(3456L, 4567L, 5678L); + obj.Floats.Should().BeApproximately([394857.5f, 48576.4f, 5768.3f]); + obj.Doubles.Should().BeApproximately([854765.45, 94765.34, 8765.23]); + obj.Shorts.Should().Equal(43, 54, 65); + obj.Bytes.Should().Equal(94, 85, 76); + } + + [Fact] + public void Should_Convert_Vector_To_IReadOnlyList() + { + var record = GetTestRecord(); + var obj = record.AsObjectFromBlueprint( + new + { + Ints = (IReadOnlyList)null!, + Longs = (IReadOnlyList)null!, + Floats = (IReadOnlyList)null!, + Doubles = (IReadOnlyList)null!, + Shorts = (IReadOnlyList)null!, + Bytes = (IReadOnlyList)null! + }); + + obj.Ints.Should().Equal(5, 6, 7); + obj.Longs.Should().Equal(3456L, 4567L, 5678L); + obj.Floats.Should().BeApproximately([394857.5f, 48576.4f, 5768.3f]); + obj.Doubles.Should().BeApproximately([854765.45, 94765.34, 8765.23]); + obj.Shorts.Should().Equal(43, 54, 65); + obj.Bytes.Should().Equal(94, 85, 76); + } + + [Fact] + public void Should_Convert_Vector_To_IReadOnlyCollection() + { + var record = GetTestRecord(); + var obj = record.AsObjectFromBlueprint( + new + { + Ints = (IReadOnlyCollection)null!, + Longs = (IReadOnlyCollection)null!, + Floats = (IReadOnlyCollection)null!, + Doubles = (IReadOnlyCollection)null!, + Shorts = (IReadOnlyCollection)null!, + Bytes = (IReadOnlyCollection)null! + }); + + obj.Ints.Should().Equal(5, 6, 7); + obj.Longs.Should().Equal(3456L, 4567L, 5678L); + obj.Floats.Should().BeApproximately([394857.5f, 48576.4f, 5768.3f]); + obj.Doubles.Should().BeApproximately([854765.45, 94765.34, 8765.23]); + obj.Shorts.Should().Equal(43, 54, 65); + obj.Bytes.Should().Equal(94, 85, 76); + } + + [Fact] + public void Should_Convert_Vector_To_ICollection() + { + var record = GetTestRecord(); + var obj = record.AsObjectFromBlueprint( + new + { + Ints = (ICollection)null!, + Longs = (ICollection)null!, + Floats = (ICollection)null!, + Doubles = (ICollection)null!, + Shorts = (ICollection)null!, + Bytes = (ICollection)null! + }); + + obj.Ints.Should().Equal(5, 6, 7); + obj.Longs.Should().Equal(3456L, 4567L, 5678L); + obj.Floats.Should().BeApproximately([394857.5f, 48576.4f, 5768.3f]); + obj.Doubles.Should().BeApproximately([854765.45, 94765.34, 8765.23]); + obj.Shorts.Should().Equal(43, 54, 65); + obj.Bytes.Should().Equal(94, 85, 76); + } +} diff --git a/Neo4j.Driver/Neo4j.Driver.Tests/TestUtil/Assertions.cs b/Neo4j.Driver/Neo4j.Driver.Tests/TestUtil/Assertions.cs index 3368347e9..d9cb66417 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests/TestUtil/Assertions.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests/TestUtil/Assertions.cs @@ -14,7 +14,9 @@ // limitations under the License. using System; +using System.Linq; using FluentAssertions; +using FluentAssertions.Collections; using FluentAssertions.Execution; using FluentAssertions.Numeric; @@ -61,4 +63,28 @@ public static Func Matches(Action assertion) return true; }; } + + public static AndConstraint> BeApproximately( + this GenericCollectionAssertions assertions, + double[] expected, + double precision = 0.0001) + { + return assertions.HaveCount(expected.Length) + .And.SatisfyRespectively( + expected.Select(e => new Action(actual => + Convert.ToDouble(actual).Should().BeApproximately(Convert.ToDouble(e), precision))) + .ToArray()); + } + + public static AndConstraint> BeApproximately( + this GenericCollectionAssertions assertions, + float[] expected, + float precision = 0.0001F) + { + return assertions.HaveCount(expected.Length) + .And.SatisfyRespectively( + expected.Select(e => new Action(actual => + Convert.ToSingle(actual).Should().BeApproximately(Convert.ToSingle(e), precision))) + .ToArray()); + } } diff --git a/Neo4j.Driver/Neo4j.Driver.Tests/Types/VectorTests.cs b/Neo4j.Driver/Neo4j.Driver.Tests/Types/VectorTests.cs new file mode 100644 index 000000000..5613006a1 --- /dev/null +++ b/Neo4j.Driver/Neo4j.Driver.Tests/Types/VectorTests.cs @@ -0,0 +1,94 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using FluentAssertions; +using Xunit; + +namespace Neo4j.Driver.Tests.Types; + +public class VectorTests +{ + [Fact] + public void ShouldNotThrowForFloatType() + { + Action act = () => _ = new Vector([1.0f, 2.0f, 3.0f]); + act.Should().NotThrow(); + } + + [Fact] + public void ShouldNotThrowForSByteType() + { + Action act = () => _ = new Vector([1, 2, 3]); + act.Should().NotThrow(); + } + + [Fact] + public void ShouldNotThrowForShortType() + { + Action act = () => _ = new Vector([1, 2, 3]); + act.Should().NotThrow(); + } + + [Fact] + public void ShouldNotThrowForIntType() + { + Action act = () => _ = new Vector([1, 2, 3]); + act.Should().NotThrow(); + } + + [Fact] + public void ShouldNotThrowForLongType() + { + Action act = () => _ = new Vector([1, 2, 3]); + act.Should().NotThrow(); + } + + [Fact] + public void ShouldNotThrowForDoubleType() + { + Action act = () => _ = new Vector([1.0, 2.0, 3.0]); + act.Should().NotThrow(); + } + + [Fact] + public void ShouldThrowForByteType() + { + Action act = () => _ = new Vector(); + act.Should().Throw(); + } + + [Fact] + public void ShouldThrowForBoolType() + { + Action act = () => _ = new Vector(); + act.Should().Throw(); + } + + [Fact] + public void ShouldInitializeVectorWithValues() + { + var values = new[] { 1, 2, 3 }; + var vector = new Vector(values); + + Assert.Equal(values, vector.Values); + } + + [Fact] + public void ShouldThrowForNull() + { + Assert.Throws(() => _ = new Vector(null)); + } +} diff --git a/Neo4j.Driver/Neo4j.Driver.sln b/Neo4j.Driver/Neo4j.Driver.sln index a8d001ea1..118f6586c 100644 --- a/Neo4j.Driver/Neo4j.Driver.sln +++ b/Neo4j.Driver/Neo4j.Driver.sln @@ -27,6 +27,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Neo4j.Driver.Tests.TestBack EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Neo4j.Driver.Tests.BenchkitBackend", "Neo4j.Driver.Tests.BenchkitBackend\Neo4j.Driver.Tests.BenchkitBackend.csproj", "{EFFB6047-6BD4-4CEE-9D5B-09DFE8B8340E}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Neo4j.Vector.Examples", "Neo4j.Vector.Examples\Neo4j.Vector.Examples.csproj", "{2ACFD238-3125-4161-A31C-72BE42CA4E76}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -91,6 +93,14 @@ Global {EFFB6047-6BD4-4CEE-9D5B-09DFE8B8340E}.Release|Any CPU.Build.0 = Release|Any CPU {EFFB6047-6BD4-4CEE-9D5B-09DFE8B8340E}.ReleaseSigned|Any CPU.ActiveCfg = Debug|Any CPU {EFFB6047-6BD4-4CEE-9D5B-09DFE8B8340E}.ReleaseSigned|Any CPU.Build.0 = Debug|Any CPU + {2ACFD238-3125-4161-A31C-72BE42CA4E76}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2ACFD238-3125-4161-A31C-72BE42CA4E76}.Debug|Any CPU.Build.0 = Debug|Any CPU + {2ACFD238-3125-4161-A31C-72BE42CA4E76}.DebugDelaySigned|Any CPU.ActiveCfg = Debug|Any CPU + {2ACFD238-3125-4161-A31C-72BE42CA4E76}.DebugDelaySigned|Any CPU.Build.0 = Debug|Any CPU + {2ACFD238-3125-4161-A31C-72BE42CA4E76}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2ACFD238-3125-4161-A31C-72BE42CA4E76}.Release|Any CPU.Build.0 = Release|Any CPU + {2ACFD238-3125-4161-A31C-72BE42CA4E76}.ReleaseSigned|Any CPU.ActiveCfg = Debug|Any CPU + {2ACFD238-3125-4161-A31C-72BE42CA4E76}.ReleaseSigned|Any CPU.Build.0 = Debug|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Extensions/CollectionExtensions.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Extensions/CollectionExtensions.cs index 4918be423..baf49789b 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Extensions/CollectionExtensions.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Extensions/CollectionExtensions.cs @@ -139,6 +139,20 @@ public static IDictionary ToDictionary(this object o) return FillDictionary(o, new Dictionary()); } + public static Type GetItemType(this IList list) + { + // Check if the list is a generic type + var type = list.GetType(); + if (type.IsGenericType) + { + // Get the generic type argument (e.g., T in List) + return type.GetGenericArguments()[0]; + } + + // If not generic, then object will do + return typeof(object); + } + private static bool TryGetDictionaryOfStringKeys(object o, out IDictionary dictionary) { dictionary = null; @@ -148,10 +162,9 @@ private static bool TryGetDictionaryOfStringKeys(object o, out IDictionary var interfaces = typeInfo.ImplementedInterfaces; - var canUse = interfaces.Any( - i => i.IsGenericType && - i.GetGenericTypeDefinition() == typeof(IDictionary<,>) && - i.GenericTypeArguments[0] == typeof(string)); + var canUse = interfaces.Any(i => i.IsGenericType && + i.GetGenericTypeDefinition() == typeof(IDictionary<,>) && + i.GenericTypeArguments[0] == typeof(string)); if (canUse) { @@ -378,7 +391,8 @@ public bool TryGetValue(string key, out object value) return false; } - public void Add(KeyValuePair item) => throw new NotSupportedException("This dictionary is read-only."); + public void Add(KeyValuePair item) => + throw new NotSupportedException("This dictionary is read-only."); public void Clear() => throw new NotSupportedException("This dictionary is read-only."); diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStream.cs b/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStream.cs index 0c17ef491..53efd3d52 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStream.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStream.cs @@ -27,6 +27,7 @@ internal static class PackStream public const byte TinyStruct = 0xB0; public const byte Null = 0xC0; public const byte Float64 = 0xC1; + public const byte Float32 = 0xC6; public const byte False = 0xC2; public const byte True = 0xC3; public const byte ReservedC4 = 0xC4; diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStreamBitConverter.cs b/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStreamBitConverter.cs index d56a4266e..c33723535 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStreamBitConverter.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStreamBitConverter.cs @@ -75,6 +75,15 @@ public static byte[] GetBytes(long value) return ToTargetEndian(bytes); } + /// Converts a float (Float32) to bytes. + /// The float (Float32) value to convert. + /// The specified float (Float32) value as an array of bytes.Converts an int (double) to bytes. /// The int (double) value to convert. /// The specified int (double) value as an array of bytes. @@ -92,7 +101,15 @@ public static byte[] GetBytes(string value) return Encoding.UTF8.GetBytes(value); } - /// Converts an byte array to a short. + /// Converts a byte array to a sbyte. + /// The byte array to convert. + /// A sbyte converted from the byte array. + public static sbyte ToSByte(byte[] bytes) + { + return unchecked((sbyte)bytes[0]); + } + + /// Converts a byte array to a short. /// The byte array to convert. /// A short converted from the byte array. public static short ToInt16(byte[] bytes) @@ -101,43 +118,49 @@ public static short ToInt16(byte[] bytes) return BitConverter.ToInt16(bytes, 0); } - /// Converts an byte array to a unsigned short. + /// Converts a byte array to an unsigned short. /// The byte array to convert. - /// A unsigned short converted from the byte array. + /// An unsigned short converted from the byte array. public static ushort ToUInt16(byte[] bytes) { bytes = ToPlatformEndian(bytes); return BitConverter.ToUInt16(bytes, 0); } - /// Converts an byte array to a int (Int32). + /// Converts a byte array to an int (Int32). /// The byte array to convert. - /// A int (Int32) converted from the byte array. + /// An int (Int32) converted from the byte array. public static int ToInt32(byte[] bytes) { bytes = ToPlatformEndian(bytes); return BitConverter.ToInt32(bytes, 0); } - /// Converts an byte array to a int (Int64). + /// Converts a byte array to an int (Int64). /// The byte array to convert. - /// A int (Int64) converted from the byte array. + /// An int (Int64) converted from the byte array. public static long ToInt64(byte[] bytes) { bytes = ToPlatformEndian(bytes); return BitConverter.ToInt64(bytes, 0); } - /// Converts an byte array to a int (double). + /// Converts a byte array to a double (Float64). /// The byte array to convert. - /// A int (double) converted from the byte array. + /// A double (Float64) converted from the byte array. public static double ToDouble(byte[] bytes) { bytes = ToPlatformEndian(bytes); return BitConverter.ToDouble(bytes, 0); } - /// Converts an byte array of a UTF8 encoded string to a string + public static float ToFloat(byte[] bytes) + { + bytes = ToPlatformEndian(bytes); + return BitConverter.ToSingle(bytes, 0); + } + + /// Converts a byte array of a UTF8 encoded string to a string /// The byte array to convert. /// A string converted from the byte array public static string ToString(byte[] bytes) @@ -148,7 +171,7 @@ public static string ToString(byte[] bytes) /// Converts the bytes to big endian. /// The bytes to convert. /// The bytes converted to big endian. - private static byte[] ToTargetEndian(byte[] bytes) + public static byte[] ToTargetEndian(byte[] bytes) { if (BitConverter.IsLittleEndian) { diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStreamReader.cs b/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStreamReader.cs index 668c77b40..30d65e9b7 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStreamReader.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStreamReader.cs @@ -17,6 +17,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.IO; +using System.Reflection; using Neo4j.Driver.Internal.Connector; using Neo4j.Driver.Internal.Messaging; using Neo4j.Driver.Internal.Protocol; @@ -77,7 +78,7 @@ public Dictionary ReadMap() return map; } - public IList ReadList() + public object ReadList() { var size = (int)ReadListHeader(); var vals = new object[size]; diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStreamWriter.cs b/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStreamWriter.cs index c5553b307..45aaf74c6 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStreamWriter.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/IO/PackStreamWriter.cs @@ -51,7 +51,7 @@ public void Write(object value) break; case byte byteValue: - WriteLong(Convert.ToInt64(byteValue)); + WriteByte(byteValue); break; case short shortValue: @@ -90,6 +90,10 @@ public void Write(object value) WriteString(stringValue); break; + case var _ when _format.TryGetWriteStructHandler(value.GetType(), out var structHandler): + structHandler.Serialize(_format.Version, this, value); + break; + case IList list: WriteList(list); break; @@ -107,17 +111,8 @@ public void Write(object value) break; default: - if (_format.WriteStructHandlers.TryGetValue(value.GetType(), out var structHandler)) - { - structHandler.Serialize(_format.Version, this, value); - } - else - { throw new ProtocolException( $"Cannot understand {nameof(value)} with type {value.GetType().FullName}"); - } - - break; } } @@ -126,6 +121,11 @@ private void WriteMessage(IMessage message) message.Serializer.Serialize(_format.Version, this, message); } + public void WriteByte(byte byteValue) + { + WriteLong(Convert.ToInt64(byteValue)); + } + public void WriteInt(int value) { WriteLong(value); @@ -228,7 +228,7 @@ public void WriteList(IList value) WriteListHeader(value.Count); foreach (var item in value) { - Write(item); + Write(item); } } } @@ -289,7 +289,7 @@ public void WriteNull() _stream.WriteByte(Null); } - private void WriteRaw(byte[] data) + public void WriteRaw(byte[] data) { _stream.Write(data); } diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/IO/SpanPackStreamReader.cs b/Neo4j.Driver/Neo4j.Driver/Internal/IO/SpanPackStreamReader.cs index 592077948..6407dd5b4 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/IO/SpanPackStreamReader.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/IO/SpanPackStreamReader.cs @@ -15,7 +15,9 @@ using System; using System.Buffers.Binary; +using System.Collections; using System.Collections.Generic; +using System.Reflection; using System.Runtime.CompilerServices; using System.Text; using Neo4j.Driver.Internal.Messaging; @@ -97,7 +99,7 @@ internal PackStreamType PeekNextType() }; } - private IList ReadList(int length) + private IList ReadList(int length) { var list = new List(length); for (var i = 0; i < length; i++) @@ -144,7 +146,7 @@ private Dictionary ReadMap(int size) return map; } - public IList ReadList() + public IList ReadList() { var length = ReadListHeader(); return ReadList(length); diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/IO/ValueSerializers/VectorSerializers/ITypedVectorSerialisationHelper.cs b/Neo4j.Driver/Neo4j.Driver/Internal/IO/ValueSerializers/VectorSerializers/ITypedVectorSerialisationHelper.cs new file mode 100644 index 000000000..c830380fb --- /dev/null +++ b/Neo4j.Driver/Neo4j.Driver/Internal/IO/ValueSerializers/VectorSerializers/ITypedVectorSerialisationHelper.cs @@ -0,0 +1,22 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace Neo4j.Driver.Internal.IO.ValueSerializers.VectorSerializers; + +public interface ITypedVectorSerialisationHelper +{ + byte TypeMarker { get; } + +} diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/IO/ValueSerializers/VectorSerializers/VectorSerializer.cs b/Neo4j.Driver/Neo4j.Driver/Internal/IO/ValueSerializers/VectorSerializers/VectorSerializer.cs new file mode 100644 index 000000000..726b4039d --- /dev/null +++ b/Neo4j.Driver/Neo4j.Driver/Internal/IO/ValueSerializers/VectorSerializers/VectorSerializer.cs @@ -0,0 +1,128 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using Neo4j.Driver.Internal.Protocol; +using Neo4j.Driver.Internal.Util; + +namespace Neo4j.Driver.Internal.IO.ValueSerializers.VectorSerializers; + +internal class VectorSerializer : IPackStreamSerializer +{ + public static VectorSerializer Instance { get; } = new (); + + public const byte VectorStructType = (byte)'V'; + private const int VectorStructSize = 2; + + /// + public byte[] ReadableStructs => [VectorStructType]; + + /// + public IEnumerable WritableTypes => [typeof(Vector)]; + + /// + public object Deserialize(BoltProtocolVersion version, PackStreamReader reader, byte signature, long size) + { + if(signature != VectorStructType) + { + throw new ProtocolException( + $"Unsupported struct signature {signature} passed to {nameof(VectorSerializer)}!"); + } + + PackStream.EnsureStructSize("Vector", VectorStructSize, size); + var typeMarker = reader.ReadBytes()[0]; + if (!MarkerToType.TryGetValue(typeMarker, out var elementType)) + { + throw new ProtocolException($"Unsupported vector element type marker 0x{typeMarker:X2}."); + } + + var byteArray = reader.ReadBytes(); + var originalByteStream = byteArray.ToArray(); + var typedArray = BytesToTypedArrayHelper.ConvertBytesToTypedArray(byteArray, elementType); + return Vector.CreateDynamic(typedArray, originalByteStream); + } + + public static byte[] GetByteStream(IVector vector) + { + var byteConverter = GetByteConverter(vector.ElementType); + var byteArray = vector.UntypedValues.Select(byteConverter).ToArray(); + var flattened = byteArray.SelectMany(b => b).ToArray(); + return flattened; + } + + public void Serialize(BoltProtocolVersion version, PackStreamWriter writer, object value) + { + var vector = value.CastOrThrow(); + writer.WriteStructHeader(VectorStructSize, VectorStructType); + + // the type marker is next + writer.WriteByteArray([TypeToMarker[vector.ElementType]]); + + // then all the values + var byteStream = GetByteStream(vector); + writer.WriteByteArray(byteStream); + } + + /// + public (object, int) DeserializeSpan(BoltProtocolVersion version, SpanPackStreamReader reader, byte signature, int size) + { + if (signature != VectorStructType) + { + throw new ProtocolException( + $"Unsupported struct signature {signature} passed to {nameof(VectorSerializer)}!"); + } + + PackStream.EnsureStructSize("Vector", VectorStructSize, size); + var typeMarker = reader.ReadBytes()[0]; + if (!MarkerToType.TryGetValue(typeMarker, out var elementType)) + { + throw new ProtocolException($"Unsupported vector element type marker 0x{typeMarker:X2}."); + } + + var byteArray = reader.ReadBytes(); + var originalByteStream = byteArray.ToArray(); + var typedArray = BytesToTypedArrayHelper.ConvertBytesToTypedArray(byteArray, elementType); + return (Vector.CreateDynamic(typedArray, originalByteStream), reader.Index); + } + + private static readonly Dictionary TypeToMarker = new() + { + { typeof(sbyte), PackStream.Int8 }, + { typeof(short), PackStream.Int16 }, + { typeof(int), PackStream.Int32 }, + { typeof(long), PackStream.Int64 }, + { typeof(float), PackStream.Float32 }, + { typeof(double), PackStream.Float64 } + }; + + private static readonly Dictionary MarkerToType = + TypeToMarker.ToDictionary(kvp => kvp.Value, kvp => kvp.Key); + + private static Func GetByteConverter(Type type) + { + return type switch + { + _ when type == typeof(sbyte) => value => PackStreamBitConverter.GetBytes(unchecked((byte)(sbyte)value)), + _ when type == typeof(short) => value => PackStreamBitConverter.GetBytes((short)value), + _ when type == typeof(int) => value => PackStreamBitConverter.GetBytes((int)value), + _ when type == typeof(long) => value => PackStreamBitConverter.GetBytes((long)value), + _ when type == typeof(float) => value => PackStreamBitConverter.GetBytes((float)value), + _ when type == typeof(double) => value => PackStreamBitConverter.GetBytes((double)value), + _ => throw new ArgumentOutOfRangeException(nameof(type), $"Unsupported vector element type {type}.") + }; + } +} diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Mapping/TypeConversion/DefaultConverters.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Mapping/TypeConversion/DefaultConverters.cs new file mode 100644 index 000000000..efbfcb790 --- /dev/null +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Mapping/TypeConversion/DefaultConverters.cs @@ -0,0 +1,102 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; + +namespace Neo4j.Driver.Internal.Mapping.TypeConversion; + +internal interface IDefaultConverters +{ + void Register(); +} + +internal class DefaultConverters : IDefaultConverters +{ + private readonly IMappingTypeConversionManager _manager; + + public DefaultConverters(IMappingTypeConversionManager manager) + { + _manager = manager ?? throw new ArgumentNullException(nameof(manager)); + } + + public void Register() + { + // string to Guid + _manager.RegisterConverter(Guid.Parse); + + // various vector conversions + RegisterVectorConverters(); + } + + private void RegisterVectorConverters() + { + foreach (var t in Vector.SupportedTypes) + { + RegisterSingleVectorTypeConverters(t); + } + } + + private void RegisterSingleVectorTypeConverters(Type elementType) + { + var vectorType = typeof(Vector<>).MakeGenericType(elementType); + var arrayType = elementType.MakeArrayType(); + + var toArrayMethod = + typeof(DefaultConverters).GetMethod(nameof(VectorToArray), BindingFlags.Static | BindingFlags.NonPublic)! + .MakeGenericMethod(elementType); + + // Create delegate from MethodInfo + var arrayConverterType = typeof(Func<,>).MakeGenericType(vectorType, arrayType); + var arrayConverter = Delegate.CreateDelegate(arrayConverterType, toArrayMethod); + + GetRegisterMethod(vectorType, arrayType).Invoke(_manager, [arrayConverter]); + + var toListMethod = + typeof(DefaultConverters).GetMethod(nameof(VectorToList), BindingFlags.Static | BindingFlags.NonPublic)! + .MakeGenericMethod(elementType); + + // Create delegate from MethodInfo + var listType = typeof(List<>).MakeGenericType(elementType); + var listConverterType = typeof(Func<,>).MakeGenericType(vectorType, listType); + var listConverter = Delegate.CreateDelegate(listConverterType, toListMethod); + + foreach (var targetType in GetAllVectorConversionTargets(elementType)) + { + GetRegisterMethod(vectorType, targetType).Invoke(_manager, [listConverter]); + } + } + private static IEnumerable GetAllVectorConversionTargets(Type vectorType) + { + yield return typeof(List<>).MakeGenericType(vectorType); + yield return typeof(IList<>).MakeGenericType(vectorType); + yield return typeof(IEnumerable<>).MakeGenericType(vectorType); + yield return typeof(IReadOnlyList<>).MakeGenericType(vectorType); + yield return typeof(IReadOnlyCollection<>).MakeGenericType(vectorType); + yield return typeof(ICollection<>).MakeGenericType(vectorType); + } + + private MethodInfo GetRegisterMethod(Type fromType, Type toType) + { + var method = typeof(IMappingTypeConversionManager).GetMethod(nameof( + IMappingTypeConversionManager.RegisterConverter)); + return method!.MakeGenericMethod(fromType, toType); + } + + private static T[] VectorToArray(Vector vector) where T : struct => vector.Values.ToArray(); + private static List VectorToList(Vector vector) where T : struct => vector.Values.ToList(); +} diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Mapping/TypeConversion/IMappingTypeConversionManager.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Mapping/TypeConversion/IMappingTypeConversionManager.cs index 0b28e52df..f2f597490 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Mapping/TypeConversion/IMappingTypeConversionManager.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Mapping/TypeConversion/IMappingTypeConversionManager.cs @@ -23,4 +23,5 @@ internal interface IMappingTypeConversionManager void RegisterConverter(Func converter); bool TryConvert(Type fromType, Type toType, object from, out object to); void Clear(); + void RegisterDefaultConverters(); } diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Mapping/TypeConversion/MappingTypeConversionManager.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Mapping/TypeConversion/MappingTypeConversionManager.cs index 743872a9c..531a2118c 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Mapping/TypeConversion/MappingTypeConversionManager.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Mapping/TypeConversion/MappingTypeConversionManager.cs @@ -25,6 +25,11 @@ internal class MappingTypeConversionManager : IMappingTypeConversionManager public void Clear() => _converters.Clear(); + /// + public void RegisterDefaultConverters() + { + } + public bool TryConvert(Type fromType, Type toType, object from, out object to) { if (_converters.TryGetValue((fromType, toType), out var converter)) diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocolFactory.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocolFactory.cs index a355dc4e7..fb36f0acb 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocolFactory.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocolFactory.cs @@ -43,21 +43,23 @@ internal class BoltProtocolFactory : IBoltProtocolFactory public static readonly BoltProtocolVersion[] SupportedVersions = new BoltProtocolVersion[] { - BoltProtocolVersion.V3_0, - BoltProtocolVersion.V4_0, - BoltProtocolVersion.V4_1, - BoltProtocolVersion.V4_2, - BoltProtocolVersion.V4_3, - BoltProtocolVersion.V4_4, - BoltProtocolVersion.V5_0, - BoltProtocolVersion.V5_1, - BoltProtocolVersion.V5_2, - BoltProtocolVersion.V5_3, - BoltProtocolVersion.V5_4, - BoltProtocolVersion.V5_5, - BoltProtocolVersion.V5_6, - BoltProtocolVersion.V5_7, - BoltProtocolVersion.V5_8 + //NOTE: CHANGE WHEN ADDING A BOLT PROTOCOL VERSION + BoltProtocolVersion.V3_0 + ,BoltProtocolVersion.V4_0 + ,BoltProtocolVersion.V4_1 + ,BoltProtocolVersion.V4_2 + ,BoltProtocolVersion.V4_3 + ,BoltProtocolVersion.V4_4 + ,BoltProtocolVersion.V5_0 + ,BoltProtocolVersion.V5_1 + ,BoltProtocolVersion.V5_2 + ,BoltProtocolVersion.V5_3 + ,BoltProtocolVersion.V5_4 + ,BoltProtocolVersion.V5_5 + ,BoltProtocolVersion.V5_6 + ,BoltProtocolVersion.V5_7 + ,BoltProtocolVersion.V5_8 + ,BoltProtocolVersion.V6_0 }; private static readonly Lazy HandshakeBytesLazy = @@ -71,11 +73,14 @@ internal class BoltProtocolFactory : IBoltProtocolFactory var versions = new[] { goGoBolt, - + + // List of protocol versions to offer in the legacy handshake (must be exactly 4 entries: manifest marker + 3 offers) + // Update these when dropping/adding legacy support. Do NOT add new major versions here (e.g. 6.0, 7.0, etc). + //Announce support for the new handshake format with no manifest range supplied. - BoltProtocolVersion.HandshakeManifestV1.PackToInt(), - - // 3 more versions max. + BoltProtocolVersion.HandshakeManifestV1.PackToInt(), + + //Legacy Handshake version 3 more versions max. Do not add newer versions in. BoltProtocolVersion.V5_8.PackToIntRange(BoltProtocolVersion.V5_0), BoltProtocolVersion.V4_4.PackToIntRange(BoltProtocolVersion.V4_2), BoltProtocolVersion.V3_0.PackToInt() @@ -99,15 +104,17 @@ public IBoltProtocol ForVersion(BoltProtocolVersion version) return version switch { // no matching versions + //NOTE: CHANGE WHEN ADDING A BOLT PROTOCOL VERSION { MajorVersion: 0, MinorVersion: 0 } => throw new NotSupportedException(NoAgreedVersion), { MajorVersion: 3, MinorVersion: 0 } => BoltProtocolV3.Instance, { MajorVersion: 4, MinorVersion: <= 4, MinorVersion: >= 1 } => BoltProtocol.Instance, { MajorVersion: 5, MinorVersion: <= 8, MinorVersion: >= 0 } => BoltProtocol.Instance, + { MajorVersion: 6, MinorVersion: >= 0, MinorVersion: <= 0 } => BoltProtocol.Instance, _ => throw new NotSupportedException( $"Protocol error, server suggested unexpected protocol version: {version}") }; } - + public static (BoltProtocolVersion version, int range) UnpackAgreedVersion(byte[] data) { var packedInt = PackStreamBitConverter.ToInt32(data); diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocolVersion.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocolVersion.cs index eff0bbf31..412c3197c 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocolVersion.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocolVersion.cs @@ -31,6 +31,7 @@ internal sealed class BoltProtocolVersion : IEquatable, ICo public static readonly BoltProtocolVersion Unknown = new(1, 0); // ReSharper disable InconsistentNaming + //NOTE: CHANGE WHEN ADDING A BOLT PROTOCOL VERSION public static readonly BoltProtocolVersion V3_0 = new(3, 0); public static readonly BoltProtocolVersion V4_0 = new(4, 0); public static readonly BoltProtocolVersion V4_1 = new(4, 1); @@ -46,8 +47,10 @@ internal sealed class BoltProtocolVersion : IEquatable, ICo public static readonly BoltProtocolVersion V5_6 = new(5, 6); public static readonly BoltProtocolVersion V5_7 = new(5, 7); public static readonly BoltProtocolVersion V5_8 = new(5, 8); + public static readonly BoltProtocolVersion V6_0 = new(6, 0); - public static readonly BoltProtocolVersion LatestVersion = V5_8; + //NOTE: CHANGE WHEN ADDING A BOLT PROTOCOL VERSION + public static readonly BoltProtocolVersion LatestVersion = V6_0; public static readonly BoltProtocolVersion HandshakeManifestV1 = new(ManifestSchema, ManifestVersion); // ReSharper restore InconsistentNaming @@ -72,7 +75,7 @@ public BoltProtocolVersion(int majorVersion, int minorVersion) public BoltProtocolVersion(int largeVersion) { - //This version of the constructor is only to be used to handle error codes that come in that are not strictly containing packed values. + //This version of the constructor is only to be used to handle error codes that come in that are not strictly containing packed values. MajorVersion = UnpackMajor(largeVersion); MinorVersion = UnpackMinor(largeVersion); _compValue = MajorVersion * 1000000 + MinorVersion; diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/MessageFormat.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/MessageFormat.cs index 44d8f23fe..9d23111c4 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/MessageFormat.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/MessageFormat.cs @@ -20,6 +20,7 @@ using Neo4j.Driver.Internal.IO.MessageSerializers; using Neo4j.Driver.Internal.IO.ValueSerializers; using Neo4j.Driver.Internal.IO.ValueSerializers.Temporal; +using Neo4j.Driver.Internal.IO.ValueSerializers.VectorSerializers; using Neo4j.Driver.Internal.Messaging; namespace Neo4j.Driver.Internal.Protocol; @@ -113,6 +114,12 @@ internal MessageFormat(BoltProtocolVersion version, DriverContext context) AddHandler(ElementRelationshipSerializer.Instance); AddHandler(ElementUnboundRelationshipSerializer.Instance); } + + if(Version >= BoltProtocolVersion.V6_0) + { + // vectors in 6.0 + + AddHandler(VectorSerializer.Instance); + } } // Test code. @@ -146,6 +153,21 @@ internal MessageFormat( public BoltProtocolVersion Version { get; } public IReadOnlyDictionary MessageReaders => _messageReaders; + public bool TryGetWriteStructHandler(Type type, out IPackStreamSerializer handler) + { + foreach (var kvp in _writerStructHandlers) + { + if (kvp.Key.IsAssignableFrom(type)) + { + handler = kvp.Value; + return true; + } + } + + handler = null; + return false; + } + private void AddMessageHandler(T instance) where T : class, IPackStreamMessageDeserializer, IPackStreamSerializer { _messageReaders.Add(instance.ReadableStructs[0], instance); diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Util/BytesToTypedArrayHelper.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Util/BytesToTypedArrayHelper.cs new file mode 100644 index 000000000..bc6c13e10 --- /dev/null +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Util/BytesToTypedArrayHelper.cs @@ -0,0 +1,59 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Concurrent; +using System.Reflection; +using System.Runtime.InteropServices; + +namespace Neo4j.Driver.Internal.Util; + +internal class BytesToTypedArrayHelper +{ + private static readonly ConcurrentDictionary> Converters = new(); + + public static Array ConvertBytesToTypedArray(byte[] bytes, Type elementType) + { + // Deal with endianness + if (BitConverter.IsLittleEndian) + { + var elementSize = Marshal.SizeOf(elementType); + for (var i = 0; i < bytes.Length; i += elementSize) + { + Array.Reverse(bytes, i, elementSize); + } + } + + var converter = Converters.GetOrAdd(elementType, CreateConverter); + return converter(bytes); + } + + private static Array CreateTypedArrayFromBytes(byte[] bytes) where T : unmanaged + { + var span = bytes.AsSpan(); + var typedSpan = MemoryMarshal.Cast(span); + return typedSpan.ToArray(); + } + + private static Func CreateConverter(Type elementType) + { + var method = typeof(BytesToTypedArrayHelper).GetMethod( + nameof(CreateTypedArrayFromBytes), + BindingFlags.NonPublic | BindingFlags.Static)! + .MakeGenericMethod(elementType); + + return bytes => (Array)method.Invoke(null, [bytes])!; + } +} diff --git a/Neo4j.Driver/Neo4j.Driver/Properties/AssemblyInfo.cs b/Neo4j.Driver/Neo4j.Driver/Properties/AssemblyInfo.cs index 1e91d5a0a..1b88074ca 100644 --- a/Neo4j.Driver/Neo4j.Driver/Properties/AssemblyInfo.cs +++ b/Neo4j.Driver/Neo4j.Driver/Properties/AssemblyInfo.cs @@ -44,6 +44,7 @@ [assembly: InternalsVisibleTo("Neo4j.Driver.Reactive")] [assembly: InternalsVisibleTo("Neo4j.Driver.Simple")] [assembly: InternalsVisibleTo("Neo4j.Driver.Tests")] +[assembly: InternalsVisibleTo("Neo4j.Vector.Examples")] [assembly: InternalsVisibleTo("Neo4j.Driver.Tests.Integration")] [assembly: InternalsVisibleTo("Neo4j.Driver.Tests.TestBackend")] // Required for Moq to function in Unit Tests diff --git a/Neo4j.Driver/Neo4j.Driver/Public/Mapping/RecordObjectMapping.cs b/Neo4j.Driver/Neo4j.Driver/Public/Mapping/RecordObjectMapping.cs index 43558cf59..5cad8ee68 100644 --- a/Neo4j.Driver/Neo4j.Driver/Public/Mapping/RecordObjectMapping.cs +++ b/Neo4j.Driver/Neo4j.Driver/Public/Mapping/RecordObjectMapping.cs @@ -55,10 +55,13 @@ public class RecordObjectMapping : IRecordObjectMapping private readonly ConcurrentDictionary _mapMethods = new(); private readonly ConcurrentDictionary _mappers = new(); private readonly IMappingTypeConversionManager _typeConversionManager = new MappingTypeConversionManager(); + private IDefaultConverters _defaultConverters; private IConventionTranslator _conventionTranslator = new NoOpConventionTranslator(); private RecordObjectMapping() { + _defaultConverters = new DefaultConverters(_typeConversionManager); + _defaultConverters.Register(); } internal static readonly RecordObjectMapping Instance = new(); @@ -208,6 +211,8 @@ internal static void Reset() Instance._mapMethods.Clear(); Instance._typeConversionManager.Clear(); Instance._conventionTranslator = new NoOpConventionTranslator(); + Instance._defaultConverters = new DefaultConverters(Instance._typeConversionManager); + Instance._defaultConverters.Register(); DefaultMapper.Reset(); } @@ -295,4 +300,9 @@ public static T MapFromBlueprint(IRecord record, T blueprint) { return ((IRecordObjectMapping)Instance).MapFromBlueprint(record, blueprint); } + + private static void RegisterDefaultTypeConverters() + { + Instance._typeConversionManager.RegisterDefaultConverters(); + } } diff --git a/Neo4j.Driver/Neo4j.Driver/Public/Types/IVector.cs b/Neo4j.Driver/Neo4j.Driver/Public/Types/IVector.cs new file mode 100644 index 000000000..4e8c6cad3 --- /dev/null +++ b/Neo4j.Driver/Neo4j.Driver/Public/Types/IVector.cs @@ -0,0 +1,51 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; + +namespace Neo4j.Driver; + +/// +/// Represents a mathematical vector with elements of supported numeric types. +/// +/// +/// Supported element types are: , , , , +/// , and . +/// +public interface IVector +{ + /// Returns the elements of the vector as an array of objects, regardless of their underlying type. + IEnumerable UntypedValues { get; } + + /// Gets the original byte stream from which the vector was deserialized, if applicable. + byte[] OriginalByteStream { get; } + + /// Gets the type of the elements contained in the vector. + Type ElementType { get; } +} + +/// +/// Represents a mathematical vector with elements of supported numeric types. +/// +/// +public interface IVector : IEquatable, IVector, IReadOnlyList + where T : struct +{ + /// + /// Gets the array of values contained in the vector. + /// + IReadOnlyList Values { get; } +} diff --git a/Neo4j.Driver/Neo4j.Driver/Public/Types/Vector.cs b/Neo4j.Driver/Neo4j.Driver/Public/Types/Vector.cs new file mode 100644 index 000000000..6278251bd --- /dev/null +++ b/Neo4j.Driver/Neo4j.Driver/Public/Types/Vector.cs @@ -0,0 +1,179 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Neo4j.Driver.Internal; +using Neo4j.Driver.Internal.Types; + +namespace Neo4j.Driver; + +/// +/// An abstract base class for a mathematical vector with elements of supported numeric types. +/// +/// +/// Supported element types are: , , , , +/// , and . +/// +public abstract class Vector : IValue, IVector, IEquatable +{ + public static readonly HashSet SupportedTypes = + [ + typeof(float), // f32 + typeof(double), // f64 + typeof(sbyte), // i8 + typeof(short), // i16 + typeof(int), // i32 + typeof(long) // i64 + ]; + + /// + /// Determines whether the specified type is supported for use as a vector element. + /// + /// The type to check for support. + /// + /// true if the specified type is supported; otherwise, false. + /// + public static bool IsSupported(Type type) + { + return SupportedTypes.Contains(type); + } + + /// + /// Gets the elements of the vector as an array of objects, regardless of their underlying type. + /// + public abstract IEnumerable UntypedValues { get; } + + /// + /// Gets the original byte stream from which the vector was deserialized, if applicable. + /// + public abstract byte[] OriginalByteStream { get; } + + /// + /// Creates a new instance from the specified collection of values. + /// + /// The type of the vector elements. Must be a supported numeric type. + /// The collection of values to initialize the vector with. + /// The original byte stream from which the vector was deserialized, if applicable. + /// A new containing the specified values. + /// Thrown if is not a supported type. + public static Vector Create(T[] values, byte[] originalByteStream = null) where T : struct + { + if (!IsSupported(typeof(T))) + { + throw new NotSupportedException($"Type {typeof(T).Name} is not supported for Vector."); + } + + return new Vector(values, originalByteStream); + } + + private static readonly MethodInfo CreateMethodInfo = typeof(Vector).GetMethod(nameof(Create)); + + internal static Vector CreateDynamic(Array values, byte[] originalByteStream = null) + { + var elementType = values.GetType().GetElementType()!; + if (!IsSupported(elementType)) + { + throw new NotSupportedException($"Type {elementType.Name} is not supported for Vector."); + } + + // Use reflection to call the generic Create method + var genericMethod = CreateMethodInfo.MakeGenericMethod(elementType); + return (Vector)genericMethod.Invoke(null, [values, originalByteStream]); + } + + internal static Vector CreateDynamic(IEnumerable values, Type elementType, byte[] originalByteStream = null) + { + return CreateDynamic(values.Select(v => v.AsType(elementType)).ToArray(), originalByteStream); + } + + /// + /// Gets the type of the elements contained in the vector. + /// + public abstract Type ElementType { get; } + + /// + public bool Equals(IVector other) + { + return other != null && UntypedValues.SequenceEqual(other.UntypedValues); + } +} + +/// +/// Represents a mathematical vector with elements of a specific supported numeric type. +/// +/// +/// The type of the vector elements. Must be one of the supported numeric types: , +/// , , , , or . +/// +public class Vector : Vector, IVector where T : struct +{ + /// + /// Initializes a new instance of the class. + /// + /// + /// Thrown if is not a supported numeric type. + /// + public Vector() + { + if (!IsSupported(typeof(T))) + { + throw new NotSupportedException($"Type {typeof(T).Name} is not supported for Vector."); + } + } + + /// + /// Initializes a new instance of the class with the specified values. + /// + /// The array of values to initialize the vector with. Must not be null or empty. + /// The original byte stream from which the vector was deserialized, if applicable. + /// Thrown if is null or empty. + /// Thrown if is not a supported numeric type. + public Vector(T[] values, byte[] originalByteStream = null) : this() + { + Values = values ?? throw new ArgumentException("Values cannot be null.", nameof(values)); + OriginalByteStream = originalByteStream; + UntypedValues = Values.Select(x => (object)x); + } + + /// + /// Gets the array of values contained in the vector. + /// + public IReadOnlyList Values { get; } + + /// + public override IEnumerable UntypedValues { get; } + + /// + public override byte[] OriginalByteStream { get; } + + /// + public override Type ElementType => typeof(T); + + /// + public IEnumerator GetEnumerator() => Values.GetEnumerator(); + + /// + public int Count => Values.Count; + + /// + public T this[int index] => Values[index]; + + /// + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); +} diff --git a/Neo4j.Driver/Neo4j.Vector.Examples/Neo4j.Vector.Examples.csproj b/Neo4j.Driver/Neo4j.Vector.Examples/Neo4j.Vector.Examples.csproj new file mode 100644 index 000000000..31aa49caa --- /dev/null +++ b/Neo4j.Driver/Neo4j.Vector.Examples/Neo4j.Vector.Examples.csproj @@ -0,0 +1,28 @@ + + + + net9.0 + latest + enable + enable + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + diff --git a/Neo4j.Driver/Neo4j.Vector.Examples/VectorExamples.cs b/Neo4j.Driver/Neo4j.Vector.Examples/VectorExamples.cs new file mode 100644 index 000000000..34dafac8e --- /dev/null +++ b/Neo4j.Driver/Neo4j.Vector.Examples/VectorExamples.cs @@ -0,0 +1,136 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Neo4j.Driver.Mapping; + +namespace Neo4j.Vector.Examples; + +using System; +using Xunit; +using FluentAssertions; +using Driver; + +public class VectorExamplesTests : IDisposable +{ + private readonly IDriver? _driver; + + public VectorExamplesTests() + { + _driver = GraphDatabase.Driver("bolt://localhost:7687", AuthTokens.Basic("neo4j", "pass")); + _driver.ExecutableQuery("MATCH (n) DETACH DELETE n").ExecuteAsync().GetAwaiter().GetResult(); + RecordObjectMapping.Reset(); + } + + public void Dispose() + { + _driver?.Dispose(); + } + + [Fact] + public async Task ShouldWriteAndReadVector() + { + // create vector + var vectorElements = new double[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; + var doubleVector = new Vector(vectorElements); + + // write node with vector + await _driver! + .ExecutableQuery("CREATE (n:ShouldWriteAndReadVector {vector: $vector}) RETURN n") + .WithParameters(new { vector = doubleVector }) + .ExecuteAsync(); + + // read node with vector + var result = await _driver + .ExecutableQuery("MATCH (n:ShouldWriteAndReadVector) RETURN n") + .ExecuteAsync(); + + var record = result.Result[0]; + var node = (INode)record["n"]; + + // Here, we expect an array of doubles. + node.Properties.Should().ContainKey("vector"); + var vectorValue = node.Properties["vector"]; + + vectorValue.Should().BeOfType>(); + var vector = (Vector)vectorValue; + vector.Values.Should().Equal(vectorElements); + } + + [Fact] + public async Task ShouldWriteAndReadCSharpRecordWithVectors() + { + // create C# object with vectors + var record = new + { + DoubleVector = new Vector([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]), + LongVector = new Vector([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + }; + + // write node with vectors + await _driver! + .ExecutableQuery("CREATE (n:ShouldWriteAndReadCSharpRecordWithVectors $record) RETURN n") + .WithParameters(new { record }) + .ExecuteAsync(); + + // read node back + var eagerResult = await _driver + .ExecutableQuery("MATCH (n:ShouldWriteAndReadCSharpRecordWithVectors) RETURN n") + .ExecuteAsync(); + + var node = (INode)eagerResult.Result[0]["n"]; + + var doubleVectorRead = (Vector)node.Properties["DoubleVector"]; + var longVectorRead = (Vector)node.Properties["LongVector"]; + + doubleVectorRead.Should().BeEquivalentTo(record.DoubleVector); + longVectorRead.Should().BeEquivalentTo(record.LongVector); + } + + public class ClassForMappingWithVector(Vector doubleVector, Vector longVector) + { + public Vector DoubleVector { get; } = doubleVector; + public Vector LongVector { get; } = longVector; + } + + [Fact] + public async Task ShouldWorkWithObjectMapping() + { + // create C# vectors + var doubleVector = new Vector([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]); + var longVector = new Vector([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + + // write node with vectors + await _driver! + .ExecutableQuery("CREATE (n:ShouldWorkWithObjectMapping $record) RETURN n") + .WithParameters(new { record = new { doubleVector, longVector } }) + .ExecuteAsync(); + + + // read node with vector + var result = await _driver + .ExecutableQuery( + """ + MATCH (n:ShouldWorkWithObjectMapping) + RETURN n.doubleVector AS doubleVector, n.longVector AS longVector + """) + .ExecuteAsync() + .AsObjectsAsync(); + + var record = result[0]; + record.DoubleVector.Should().BeEquivalentTo(doubleVector); + record.LongVector.Should().BeEquivalentTo(longVector); + } + +} diff --git a/Neo4j.Driver/Neo4j.Vector.Examples/docker-compose.yml b/Neo4j.Driver/Neo4j.Vector.Examples/docker-compose.yml new file mode 100644 index 000000000..7f7a72172 --- /dev/null +++ b/Neo4j.Driver/Neo4j.Vector.Examples/docker-compose.yml @@ -0,0 +1,12 @@ +# run `docker compose up` -d to start the Neo4j database running locally before running the examples + +services: + neo: + image: neo4j:5.26 + ports: + - 7687:7687 + - 7474:7474 + environment: + - NEO4J_AUTH=neo4j/pass + - NEO4J_ACCEPT_LICENSE_AGREEMENT=yes + - NEO4J_dbms_security_auth__minimum__password__length=4 \ No newline at end of file