diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferIndexWriter.java
index 89a154d03..d4410a30c 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferIndexWriter.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferIndexWriter.java
@@ -236,6 +236,9 @@ public void writeFloat(float v) {
     @Override
     public void writeFloats(float[] floats, int offset, int count) throws IOException {
         buffer.asFloatBuffer().put(floats, offset, count);
+        // asFloatBuffer() returns a view with its own independent position, so the
+        // bulk put above does not move this buffer; advance it past the bytes written.
+        buffer.position(buffer.position() + count * Float.BYTES);
     }
 
     @Override
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayVectorFloat.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayVectorFloat.java
index 573328337..32dce1f35 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayVectorFloat.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayVectorFloat.java
@@ -16,9 +16,11 @@
 
 package io.github.jbellis.jvector.vector;
 
+import io.github.jbellis.jvector.disk.IndexWriter;
 import io.github.jbellis.jvector.util.RamUsageEstimator;
 import io.github.jbellis.jvector.vector.types.VectorFloat;
 
+import java.io.IOException;
 import java.util.Arrays;
 
 /**
@@ -65,6 +67,11 @@ public int length()
         return data.length;
     }
 
+    @Override
+    public void writeTo(IndexWriter writer) throws IOException {
+        writer.writeFloats(this.get(), 0, data.length);
+    }
+
     @Override
     public VectorFloat copy()
     {
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorFloat.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorFloat.java
index 636fa7d4d..e749ccdc5 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorFloat.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorFloat.java
@@ -16,8 +16,11 @@
 
 package io.github.jbellis.jvector.vector.types;
 
+import io.github.jbellis.jvector.disk.IndexWriter;
 import io.github.jbellis.jvector.util.Accountable;
 
+import java.io.IOException;
+
 public interface VectorFloat extends Accountable
 {
     /**
@@ -31,6 +34,14 @@ default int offset(int i) {
         return i;
     }
 
+    /**
+     * Writes this vector's elements, in index order, to the given writer as floats.
+     *
+     * @param indexWriter the writer that receives the float values
+     * @throws IOException if the writer reports an I/O failure
+     */
+    void writeTo(IndexWriter indexWriter) throws IOException;
+
     VectorFloat copy();
 
     void copyFrom(VectorFloat src, int srcOffset, int destOffset, int length);
diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorFloat.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorFloat.java
index 27df8d8b5..3e3f67e1f 100644
--- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorFloat.java
+++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorFloat.java
@@ -16,9 +16,11 @@
 
 package io.github.jbellis.jvector.vector;
 
+import io.github.jbellis.jvector.disk.IndexWriter;
 import io.github.jbellis.jvector.util.RamUsageEstimator;
 import io.github.jbellis.jvector.vector.types.VectorFloat;
 
+import java.io.IOException;
 import java.lang.foreign.MemorySegment;
 import java.nio.Buffer;
 
@@ -86,6 +88,13 @@ public int offset(int i)
         return i * Float.BYTES;
     }
 
+    @Override
+    public void writeTo(IndexWriter writer) throws IOException {
+        for (int i = 0; i < length(); i++) {
+            writer.writeFloat(get(i));
+        }
+    }
+
     @Override
     public VectorFloat copy()
     {
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/disk/TestByteBufferIndexWriter.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/disk/TestByteBufferIndexWriter.java
new file mode 100644
index 000000000..d84fa5876
--- /dev/null
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/disk/TestByteBufferIndexWriter.java
@@ -0,0 +1,246 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.github.jbellis.jvector.disk;
+
+import io.github.jbellis.jvector.LuceneTestCase;
+import io.github.jbellis.jvector.vector.VectorUtil;
+import io.github.jbellis.jvector.vector.VectorizationProvider;
+import io.github.jbellis.jvector.vector.types.VectorFloat;
+import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import static org.junit.Assert.*;
+
+public class TestByteBufferIndexWriter extends LuceneTestCase {
+
+    @Test
+    public void testWriteFloatsAdvancesPosition() throws IOException {
+        ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
+
+        float[] floats = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+
+        // Write floats
+        writer.writeFloats(floats, 0, floats.length);
+
+        // Verify position advanced correctly
+        assertEquals(floats.length * Float.BYTES, writer.position());
+
+        // Write more data to ensure position is correct
+        writer.writeInt(42);
+        assertEquals(floats.length * Float.BYTES + Integer.BYTES, writer.position());
+    }
+
+    @Test
+    public void testWriteFloatsWithOffsetAndCount() throws IOException {
+        ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
+
+        float[] floats = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+
+        // Write only middle 3 elements
+        writer.writeFloats(floats, 1, 3);
+
+        // Verify position
+        assertEquals(3 * Float.BYTES, writer.position());
+
+        // Verify content
+        ByteBuffer buffer = writer.getWrittenData();
+        assertEquals(2.0f, buffer.getFloat(), 0.001f);
+        assertEquals(3.0f, buffer.getFloat(), 0.001f);
+        assertEquals(4.0f, buffer.getFloat(), 0.001f);
+    }
+
+    @Test
+    public void testWriteFloatVectorIntegration() throws IOException {
+        VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
+        ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
+
+        // Create a vector
+        float[] data = {1.5f, 2.5f, 3.5f, 4.5f};
+        VectorFloat vector = vts.createFloatVector(data);
+
+        // Write vector using VectorTypeSupport
+        vts.writeFloatVector(writer, vector);
+
+        // Verify position advanced
+        assertEquals(data.length * Float.BYTES, writer.position());
+
+        // Verify content
+        ByteBuffer buffer = writer.getWrittenData();
+        for (float expected : data) {
+            assertEquals(expected, buffer.getFloat(), 0.001f);
+        }
+    }
+
+    @Test
+    public void testMultipleWriteFloatsCalls() throws IOException {
+        ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
+
+        float[] floats1 = {1.0f, 2.0f};
+        float[] floats2 = {3.0f, 4.0f, 5.0f};
+
+        writer.writeFloats(floats1, 0, floats1.length);
+        long pos1 = writer.position();
+        assertEquals(floats1.length * Float.BYTES, pos1);
+
+        writer.writeFloats(floats2, 0, floats2.length);
+        long pos2 = writer.position();
+        assertEquals((floats1.length + floats2.length) * Float.BYTES, pos2);
+
+        // Verify all data was written correctly
+        ByteBuffer buffer = writer.getWrittenData();
+        assertEquals(1.0f, buffer.getFloat(), 0.001f);
+        assertEquals(2.0f, buffer.getFloat(), 0.001f);
+        assertEquals(3.0f, buffer.getFloat(), 0.001f);
+        assertEquals(4.0f, buffer.getFloat(), 0.001f);
+        assertEquals(5.0f, buffer.getFloat(), 0.001f);
+    }
+
+    @Test
+    public void testWriteFloatsDoesNotOverwrite() throws IOException {
+        ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
+
+        // Write an int first
+        writer.writeInt(999);
+        long posAfterInt = writer.position();
+
+        // Write floats
+        float[] floats = {1.0f, 2.0f, 3.0f};
+        writer.writeFloats(floats, 0, floats.length);
+
+        // Verify position
+        assertEquals(posAfterInt + floats.length * Float.BYTES, writer.position());
+
+        // Verify the int wasn't overwritten
+        ByteBuffer buffer = writer.getWrittenData();
+        assertEquals(999, buffer.getInt());
+        assertEquals(1.0f, buffer.getFloat(), 0.001f);
+        assertEquals(2.0f, buffer.getFloat(), 0.001f);
+        assertEquals(3.0f, buffer.getFloat(), 0.001f);
+    }
+
+    @Test
+    public void testWriteFloatsWithDirectBuffer() throws IOException {
+        ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, true);
+
+        float[] floats = {1.0f, 2.0f, 3.0f};
+        writer.writeFloats(floats, 0, floats.length);
+
+        assertEquals(floats.length * Float.BYTES, writer.position());
+
+        ByteBuffer buffer = writer.getWrittenData();
+        for (float expected : floats) {
+            assertEquals(expected, buffer.getFloat(), 0.001f);
+        }
+    }
+
+    @Test
+    public void testBytesWritten() throws IOException {
+        ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
+
+        assertEquals(0, writer.bytesWritten());
+
+        float[] floats = {1.0f, 2.0f, 3.0f};
+        writer.writeFloats(floats, 0, floats.length);
+
+        assertEquals(floats.length * Float.BYTES, writer.bytesWritten());
+
+        writer.writeInt(42);
+        assertEquals(floats.length * Float.BYTES + Integer.BYTES, writer.bytesWritten());
+    }
+
+    @Test
+    public void testArrayVectorFloatWriteTo() throws IOException {
+        VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
+        ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
+
+        // Create an ArrayVectorFloat
+        float[] data = {1.5f, 2.5f, 3.5f, 4.5f, 5.5f};
+        VectorFloat vector = vts.createFloatVector(data);
+
+        // Call writeTo which internally calls writeFloats
+        vector.writeTo(writer);
+
+        // Verify position advanced correctly
+        assertEquals(data.length * Float.BYTES, writer.position());
+        assertEquals(data.length * Float.BYTES, writer.bytesWritten());
+
+        // Verify content was written correctly
+        ByteBuffer buffer = writer.getWrittenData();
+        for (float expected : data) {
+            assertEquals(expected, buffer.getFloat(), 0.001f);
+        }
+    }
+
+    @Test
+    public void testArrayVectorFloatWriteToMultipleTimes() throws IOException {
+        VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
+        ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
+
+        // Create two vectors
+        float[] data1 = {1.0f, 2.0f, 3.0f};
+        float[] data2 = {4.0f, 5.0f, 6.0f, 7.0f};
+        VectorFloat vector1 = vts.createFloatVector(data1);
+        VectorFloat vector2 = vts.createFloatVector(data2);
+
+        // Write first vector
+        vector1.writeTo(writer);
+        long pos1 = writer.position();
+        assertEquals(data1.length * Float.BYTES, pos1);
+
+        // Write second vector
+        vector2.writeTo(writer);
+        long pos2 = writer.position();
+        assertEquals((data1.length + data2.length) * Float.BYTES, pos2);
+
+        // Verify all data was written correctly
+        ByteBuffer buffer = writer.getWrittenData();
+        assertEquals(1.0f, buffer.getFloat(), 0.001f);
+        assertEquals(2.0f, buffer.getFloat(), 0.001f);
+        assertEquals(3.0f, buffer.getFloat(), 0.001f);
+        assertEquals(4.0f, buffer.getFloat(), 0.001f);
+        assertEquals(5.0f, buffer.getFloat(), 0.001f);
+        assertEquals(6.0f, buffer.getFloat(), 0.001f);
+        assertEquals(7.0f, buffer.getFloat(), 0.001f);
+    }
+
+    @Test
+    public void testArrayVectorFloatWriteToWithDirectBuffer() throws IOException {
+        VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
+        ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, true);
+
+        // Create a vector
+        float[] data = {10.5f, 20.5f, 30.5f};
+        VectorFloat vector = vts.createFloatVector(data);
+
+        // Write to direct buffer
+        vector.writeTo(writer);
+
+        // Verify position
+        assertEquals(data.length * Float.BYTES, writer.position());
+
+        // Verify content
+        ByteBuffer buffer = writer.getWrittenData();
+        for (float expected : data) {
+            assertEquals(expected, buffer.getFloat(), 0.001f);
+        }
+    }
+}
+
+// Made with Bob