From 9fe622588feb0909364d43764187f5819827c83d Mon Sep 17 00:00:00 2001 From: Artem Yegorov Date: Thu, 21 May 2026 15:17:23 +0300 Subject: [PATCH 1/5] feat: add keyword argument support for `train_from_stream` (#114) --- lib/classifier/bayes.rb | 27 ++++++++------ lib/classifier/knn.rb | 6 +-- lib/classifier/logistic_regression.rb | 43 ++++++++++++---------- lib/classifier/lsi.rb | 38 ++++++++++--------- lib/classifier/streaming.rb | 4 +- test/bayes/streaming_test.rb | 17 +++++++++ test/knn/streaming_test.rb | 29 +++++++++++++++ test/logistic_regression/streaming_test.rb | 31 ++++++++++++++++ test/lsi/streaming_test.rb | 17 +++++++++ 9 files changed, 158 insertions(+), 54 deletions(-) create mode 100644 test/knn/streaming_test.rb create mode 100644 test/logistic_regression/streaming_test.rb diff --git a/lib/classifier/bayes.rb b/lib/classifier/bayes.rb index 09a9a628..b2b0b92b 100644 --- a/lib/classifier/bayes.rb +++ b/lib/classifier/bayes.rb @@ -328,20 +328,23 @@ def remove_category(category) # puts "#{progress.completed} documents processed" # end # - # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void - def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE) - category = category.prepare_category_name - raise StandardError, "No such category: #{category}" unless @categories.key?(category) + # @rbs (?(String | Symbol), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Streaming::Progress) -> void } -> void + def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories) + (category && io ? { category => io } : categories).each do |category, io| + category = category.prepare_category_name + raise StandardError, "No such category: #{category}" unless @categories.key?(category) + raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) - reader = Streaming::LineReader.new(io, batch_size: batch_size) - total = reader.estimate_line_count - progress = Streaming::Progress.new(total: total) + reader = Streaming::LineReader.new(io, batch_size: batch_size) + total = reader.estimate_line_count + progress = Streaming::Progress.new(total: total) - reader.each_batch do |batch| - train_batch_internal(category, batch) - progress.completed += batch.size - progress.current_batch += 1 - yield progress if block_given? + reader.each_batch do |batch| + train_batch_internal(category, batch) + progress.completed += batch.size + progress.current_batch += 1 + yield progress if block_given? + end end end diff --git a/lib/classifier/knn.rb b/lib/classifier/knn.rb index 37ae8174..a9e29f95 100644 --- a/lib/classifier/knn.rb +++ b/lib/classifier/knn.rb @@ -268,9 +268,9 @@ def self.load_checkpoint(storage:, checkpoint_id:) # puts "#{progress.completed} documents processed" # end # - # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void - def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE, &block) - @lsi.train_from_stream(category, io, batch_size: batch_size, &block) + # @rbs (?(String | Symbol), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Streaming::Progress) -> void } -> void + def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block) + @lsi.train_from_stream(category, io, batch_size: batch_size, **categories, &block) synchronize { @dirty = true } end diff --git a/lib/classifier/logistic_regression.rb b/lib/classifier/logistic_regression.rb index 72de3176..a9d738de 100644 --- a/lib/classifier/logistic_regression.rb +++ b/lib/classifier/logistic_regression.rb @@ -390,28 +390,31 @@ def self.load_checkpoint(storage:, checkpoint_id:) # end # classifier.fit # - # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void - def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE) - category = category.to_s.prepare_category_name - raise StandardError, "No such category: #{category}" unless @categories.include?(category) - - reader = Streaming::LineReader.new(io, batch_size: batch_size) - total = reader.estimate_line_count - progress = Streaming::Progress.new(total: total) - - reader.each_batch do |batch| - synchronize do - batch.each do |text| - features = text.word_hash(@min_word_length) - features.each_key { |word| @vocabulary[word] = true } - @training_data << { category: category, features: features } + # @rbs (?(String | Symbol), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Streaming::Progress) -> void } -> void + def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories) + (category && io ? { category => io } : categories).each do |category, io| + category = category.to_s.prepare_category_name + raise StandardError, "No such category: #{category}" unless @categories.include?(category) + raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) + + reader = Streaming::LineReader.new(io, batch_size: batch_size) + total = reader.estimate_line_count + progress = Streaming::Progress.new(total: total) + + reader.each_batch do |batch| + synchronize do + batch.each do |text| + features = text.word_hash(@min_word_length) + features.each_key { |word| @vocabulary[word] = true } + @training_data << { category: category, features: features } + end + @fitted = false + @dirty = true end - @fitted = false - @dirty = true + progress.completed += batch.size + progress.current_batch += 1 + yield progress if block_given? end - progress.completed += batch.size - progress.current_batch += 1 - yield progress if block_given? end end diff --git a/lib/classifier/lsi.rb b/lib/classifier/lsi.rb index 5c277150..f81ed929 100644 --- a/lib/classifier/lsi.rb +++ b/lib/classifier/lsi.rb @@ -662,25 +662,29 @@ def self.load_checkpoint(storage:, checkpoint_id:) # puts "#{progress.completed} documents processed" # end # - # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void - def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE) - original_auto_rebuild = @auto_rebuild - @auto_rebuild = false + # @rbs (?(String | Symbol), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Streaming::Progress) -> void } -> void + def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories) + (category && io ? { category => io } : categories).each do |category, io| + raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) - begin - reader = Streaming::LineReader.new(io, batch_size: batch_size) - total = reader.estimate_line_count - progress = Streaming::Progress.new(total: total) - - reader.each_batch do |batch| - batch.each { |text| add_item(text, category) } - progress.completed += batch.size - progress.current_batch += 1 - yield progress if block_given? + original_auto_rebuild = @auto_rebuild + @auto_rebuild = false + + begin + reader = Streaming::LineReader.new(io, batch_size: batch_size) + total = reader.estimate_line_count + progress = Streaming::Progress.new(total: total) + + reader.each_batch do |batch| + batch.each { |text| add_item(text, category) } + progress.completed += batch.size + progress.current_batch += 1 + yield progress if block_given? + end + ensure + @auto_rebuild = original_auto_rebuild + build_index if original_auto_rebuild end - ensure - @auto_rebuild = original_auto_rebuild - build_index if original_auto_rebuild end end diff --git a/lib/classifier/streaming.rb b/lib/classifier/streaming.rb index 3c228b41..3560e702 100644 --- a/lib/classifier/streaming.rb +++ b/lib/classifier/streaming.rb @@ -26,8 +26,8 @@ module Streaming # Trains the classifier from an IO stream. # Each line in the stream is treated as a separate document. # - # @rbs (Symbol | String, IO, ?batch_size: Integer) { (Progress) -> void } -> void - def train_from_stream(category, io, batch_size: DEFAULT_BATCH_SIZE, &block) + # @rbs (Symbol | String, IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Progress) -> void } -> void + def train_from_stream(category, io, batch_size: DEFAULT_BATCH_SIZE, **categories, &block) raise NotImplementedError, "#{self.class} must implement train_from_stream" end diff --git a/test/bayes/streaming_test.rb b/test/bayes/streaming_test.rb index fd2f311a..73771ad0 100644 --- a/test/bayes/streaming_test.rb +++ b/test/bayes/streaming_test.rb @@ -17,6 +17,23 @@ def test_train_from_stream_basic assert_equal 'Spam', @classifier.classify('buy cheap free') end + def test_train_from_stream_many_categories + classifier = Classifier::Bayes.new('Spam', 'Ham') + classifier.train_from_stream( + spam: StringIO.new("buy now cheap\nfree money\nlimited offer\n"), + ham: StringIO.new("hello friend\nmeeting tomorrow\n") + ) + + assert_equal 'Spam', classifier.classify('buy free') + assert_equal 'Ham', classifier.classify('hello meeting') + end + + def test_train_from_stream_invalid_io_type + assert_raises(StandardError) do + @classifier.train_from_stream(spam: Object.new) + end + end + def test_train_from_stream_empty_io io = StringIO.new('') @classifier.train_from_stream(:spam, io) diff --git a/test/knn/streaming_test.rb b/test/knn/streaming_test.rb new file mode 100644 index 00000000..f71885fb --- /dev/null +++ b/test/knn/streaming_test.rb @@ -0,0 +1,29 @@ +require_relative '../test_helper' +require 'stringio' + +class KNNStreamingTest < Minitest::Test + def test_train_from_stream_basic + knn = Classifier::KNN.new + knn.train_from_stream(:spam, StringIO.new("buy now cheap\nfree money\nlimited offer\n")) + + assert_equal 'spam', knn.classify('buy cheap free') + end + + def test_train_from_stream_many_categories + knn = Classifier::KNN.new + knn.train_from_stream( + spam: StringIO.new("buy now cheap\nfree money\nlimited offer\n"), + ham: StringIO.new("hello friend\nmeeting tomorrow\nhello fellow\n") + ) + + assert_equal 'spam', knn.classify('free offer') + assert_equal 'ham', knn.classify('hello') + end + + def test_train_from_stream_invalid_io_type + knn = Classifier::KNN.new + assert_raises(StandardError) do + knn.train_from_stream(spam: Object.new) + end + end +end diff --git a/test/logistic_regression/streaming_test.rb b/test/logistic_regression/streaming_test.rb new file mode 100644 index 00000000..ac06568a --- /dev/null +++ b/test/logistic_regression/streaming_test.rb @@ -0,0 +1,31 @@ +require_relative '../test_helper' +require 'stringio' + +class LogisticRegressionStreamingTest < Minitest::Test + def test_train_from_stream_basic + classifier = Classifier::LogisticRegression.new('Spam', 'Ham') + classifier.train_from_stream(:spam, StringIO.new("buy now cheap\nfree money\nlimited offer\n")) + classifier.fit + + assert_equal 'Spam', classifier.classify('buy cheap free') + end + + def test_train_from_stream_many_categories + classifier = Classifier::LogisticRegression.new('Spam', 'Ham') + classifier.train_from_stream( + spam: StringIO.new("buy now cheap\nfree money\nlimited offer\n"), + ham: StringIO.new("hello friend\nmeeting tomorrow\n") + ) + classifier.fit + + assert_equal 'Spam', classifier.classify('buy free') + assert_equal 'Ham', classifier.classify('hello meeting') + end + + def test_train_from_stream_invalid_io_type + classifier = Classifier::LogisticRegression.new('Spam', 'Ham') + assert_raises(StandardError) do + classifier.train_from_stream(spam: Object.new) + end + end +end diff --git a/test/lsi/streaming_test.rb b/test/lsi/streaming_test.rb index 94bfaafb..c9327dfb 100644 --- a/test/lsi/streaming_test.rb +++ b/test/lsi/streaming_test.rb @@ -23,6 +23,23 @@ def test_train_from_stream_basic assert_equal 'dog', result.to_s end + def test_train_from_stream_many_categories + lsi = Classifier::LSI.new + lsi.train_from_stream( + dog: StringIO.new("dogs are loyal pets\npuppies are playful\ndogs bark at strangers\n"), + cat: StringIO.new("cats are independent\nkittens are curious\ncats meow softly\n") + ) + + assert_equal :dog, lsi.classify('loyal pet that barks') + assert_equal :cat, lsi.classify('independent curious pet') + end + + def test_train_from_stream_invalid_io_type + assert_raises(StandardError) do + @lsi.train_from_stream(category: Object.new) + end + end + def test_train_from_stream_empty_io @lsi.train_from_stream(:category, StringIO.new('')) From 7c8837fb70069459e07ed7446ebc874354909514 Mon Sep 17 00:00:00 2001 From: Artem Yegorov Date: Thu, 21 May 2026 15:48:51 +0300 Subject: [PATCH 2/5] fix: greptile comments (#114) --- lib/classifier/bayes.rb | 3 +++ lib/classifier/lsi.rb | 29 +++++++++++++---------------- lib/classifier/streaming.rb | 4 ++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/lib/classifier/bayes.rb b/lib/classifier/bayes.rb index b2b0b92b..5959e427 100644 --- a/lib/classifier/bayes.rb +++ b/lib/classifier/bayes.rb @@ -330,6 +330,9 @@ def remove_category(category) # # @rbs (?(String | Symbol), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Streaming::Progress) -> void } -> void def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories) + raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty? + raise ArgumentError, 'Provide both category and io, or use keyword arguments' if category.nil? ^ io.nil? + (category && io ? { category => io } : categories).each do |category, io| category = category.prepare_category_name raise StandardError, "No such category: #{category}" unless @categories.key?(category) diff --git a/lib/classifier/lsi.rb b/lib/classifier/lsi.rb index f81ed929..37f07ff7 100644 --- a/lib/classifier/lsi.rb +++ b/lib/classifier/lsi.rb @@ -664,28 +664,25 @@ def self.load_checkpoint(storage:, checkpoint_id:) # # @rbs (?(String | Symbol), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Streaming::Progress) -> void } -> void def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories) + original_auto_rebuild = @auto_rebuild + @auto_rebuild = false (category && io ? { category => io } : categories).each do |category, io| raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) - original_auto_rebuild = @auto_rebuild - @auto_rebuild = false + reader = Streaming::LineReader.new(io, batch_size: batch_size) + total = reader.estimate_line_count + progress = Streaming::Progress.new(total: total) - begin - reader = Streaming::LineReader.new(io, batch_size: batch_size) - total = reader.estimate_line_count - progress = Streaming::Progress.new(total: total) - - reader.each_batch do |batch| - batch.each { |text| add_item(text, category) } - progress.completed += batch.size - progress.current_batch += 1 - yield progress if block_given? - end - ensure - @auto_rebuild = original_auto_rebuild - build_index if original_auto_rebuild + reader.each_batch do |batch| + batch.each { |text| add_item(text, category) } + progress.completed += batch.size + progress.current_batch += 1 + yield progress if block_given? end end + ensure + @auto_rebuild = original_auto_rebuild + build_index if original_auto_rebuild end # Adds items to the index in batches from an array. diff --git a/lib/classifier/streaming.rb b/lib/classifier/streaming.rb index 3560e702..6b80f44b 100644 --- a/lib/classifier/streaming.rb +++ b/lib/classifier/streaming.rb @@ -26,8 +26,8 @@ module Streaming # Trains the classifier from an IO stream. # Each line in the stream is treated as a separate document. # - # @rbs (Symbol | String, IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Progress) -> void } -> void - def train_from_stream(category, io, batch_size: DEFAULT_BATCH_SIZE, **categories, &block) + # @rbs (?(Symbol | String), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Progress) -> void } -> void + def train_from_stream(category = nil, io = nil, batch_size: DEFAULT_BATCH_SIZE, **categories, &block) raise NotImplementedError, "#{self.class} must implement train_from_stream" end From 069db71a1c59a747c508e112d42b088d2019ddde Mon Sep 17 00:00:00 2001 From: Artem Yegorov Date: Thu, 21 May 2026 16:13:40 +0300 Subject: [PATCH 3/5] fix: rubocop offenses (#114) --- lib/classifier/bayes.rb | 36 ++++++++++-------- lib/classifier/logistic_regression.rb | 55 ++++++++++++++++----------- lib/classifier/lsi.rb | 35 ++++++++++------- 3 files changed, 75 insertions(+), 51 deletions(-) diff --git a/lib/classifier/bayes.rb b/lib/classifier/bayes.rb index 5959e427..f9fdb459 100644 --- a/lib/classifier/bayes.rb +++ b/lib/classifier/bayes.rb @@ -329,25 +329,12 @@ def remove_category(category) # end # # @rbs (?(String | Symbol), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Streaming::Progress) -> void } -> void - def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories) + def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty? raise ArgumentError, 'Provide both category and io, or use keyword arguments' if category.nil? ^ io.nil? (category && io ? { category => io } : categories).each do |category, io| - category = category.prepare_category_name - raise StandardError, "No such category: #{category}" unless @categories.key?(category) - raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) - - reader = Streaming::LineReader.new(io, batch_size: batch_size) - total = reader.estimate_line_count - progress = Streaming::Progress.new(total: total) - - reader.each_batch do |batch| - train_batch_internal(category, batch) - progress.completed += batch.size - progress.current_batch += 1 - yield progress if block_given? - end + stream_train_category(category, io, batch_size: batch_size, &) end end @@ -395,6 +382,25 @@ def self.load_checkpoint(storage:, checkpoint_id:) private + # Trains from an IO stream with a single category. + # @rbs (String | Symbol, IO, batch_size: Integer) { (Streaming::Progress) -> void } -> void + def stream_train_category(category, io, batch_size:) + category = category.prepare_category_name + raise StandardError, "No such category: #{category}" unless @categories.key?(category) + raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) + + reader = Streaming::LineReader.new(io, batch_size: batch_size) + total = reader.estimate_line_count + progress = Streaming::Progress.new(total: total) + + reader.each_batch do |batch| + train_batch_internal(category, batch) + progress.completed += batch.size + progress.current_batch += 1 + yield progress if block_given? + end + end + # Trains a batch of documents for a single category. # @rbs (String | Symbol, Array[String], ?batch_size: Integer) { (Streaming::Progress) -> void } -> void def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT_BATCH_SIZE) diff --git a/lib/classifier/logistic_regression.rb b/lib/classifier/logistic_regression.rb index a9d738de..35a3543d 100644 --- a/lib/classifier/logistic_regression.rb +++ b/lib/classifier/logistic_regression.rb @@ -391,30 +391,12 @@ def self.load_checkpoint(storage:, checkpoint_id:) # classifier.fit # # @rbs (?(String | Symbol), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Streaming::Progress) -> void } -> void - def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories) + def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) + raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty? + raise ArgumentError, 'Provide both category and io, or use keyword arguments' if category.nil? ^ io.nil? + (category && io ? { category => io } : categories).each do |category, io| - category = category.to_s.prepare_category_name - raise StandardError, "No such category: #{category}" unless @categories.include?(category) - raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) - - reader = Streaming::LineReader.new(io, batch_size: batch_size) - total = reader.estimate_line_count - progress = Streaming::Progress.new(total: total) - - reader.each_batch do |batch| - synchronize do - batch.each do |text| - features = text.word_hash(@min_word_length) - features.each_key { |word| @vocabulary[word] = true } - @training_data << { category: category, features: features } - end - @fitted = false - @dirty = true - end - progress.completed += batch.size - progress.current_batch += 1 - yield progress if block_given? - end + stream_train_category(category, io, batch_size:, &) end end @@ -443,6 +425,33 @@ def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_ private + # Trains from an IO stream with a single category. + # @rbs (String | Symbol, IO, batch_size: Integer) { (Streaming::Progress) -> void } -> void + def stream_train_category(category, io, batch_size:) + category = category.to_s.prepare_category_name + raise StandardError, "No such category: #{category}" unless @categories.include?(category) + raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) + + reader = Streaming::LineReader.new(io, batch_size: batch_size) + total = reader.estimate_line_count + progress = Streaming::Progress.new(total: total) + + reader.each_batch do |batch| + synchronize do + batch.each do |text| + features = text.word_hash(@min_word_length) + features.each_key { |word| @vocabulary[word] = true } + @training_data << { category: category, features: features } + end + @fitted = false + @dirty = true + end + progress.completed += batch.size + progress.current_batch += 1 + yield progress if block_given? + end + end + # Trains a batch of documents for a single category. # @rbs (String | Symbol, Array[String], ?batch_size: Integer) { (Streaming::Progress) -> void } -> void def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT_BATCH_SIZE) diff --git a/lib/classifier/lsi.rb b/lib/classifier/lsi.rb index 37f07ff7..90f8f10a 100644 --- a/lib/classifier/lsi.rb +++ b/lib/classifier/lsi.rb @@ -663,22 +663,14 @@ def self.load_checkpoint(storage:, checkpoint_id:) # end # # @rbs (?(String | Symbol), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Streaming::Progress) -> void } -> void - def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories) + def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) + raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty? + raise ArgumentError, 'Provide both category and io, or use keyword arguments' if category.nil? ^ io.nil? + original_auto_rebuild = @auto_rebuild @auto_rebuild = false (category && io ? { category => io } : categories).each do |category, io| - raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) - - reader = Streaming::LineReader.new(io, batch_size: batch_size) - total = reader.estimate_line_count - progress = Streaming::Progress.new(total: total) - - reader.each_batch do |batch| - batch.each { |text| add_item(text, category) } - progress.completed += batch.size - progress.current_batch += 1 - yield progress if block_given? - end + stream_train_category(category, io, batch_size:, &) end ensure @auto_rebuild = original_auto_rebuild @@ -730,6 +722,23 @@ def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_ private + # Trains from an IO stream with a single category. + # @rbs (String | Symbol, IO, batch_size: Integer) { (Streaming::Progress) -> void } -> void + def stream_train_category(category, io, batch_size:) + raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) + + reader = Streaming::LineReader.new(io, batch_size: batch_size) + total = reader.estimate_line_count + progress = Streaming::Progress.new(total: total) + + reader.each_batch do |batch| + batch.each { |text| add_item(text, category) } + progress.completed += batch.size + progress.current_batch += 1 + yield progress if block_given? + end + end + # Restores LSI state from a JSON string (used by reload) # @rbs (String) -> void def restore_from_json(json) From 0de3ae21e7a1046fe5ba9afdb08064f3a5e1146d Mon Sep 17 00:00:00 2001 From: Artem Yegorov Date: Thu, 21 May 2026 16:39:30 +0300 Subject: [PATCH 4/5] fix: rbs problems (#114) --- lib/classifier/bayes.rb | 4 ++-- lib/classifier/knn.rb | 3 ++- lib/classifier/logistic_regression.rb | 4 ++-- lib/classifier/lsi.rb | 4 ++-- lib/classifier/streaming.rb | 2 +- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/lib/classifier/bayes.rb b/lib/classifier/bayes.rb index f9fdb459..59272b26 100644 --- a/lib/classifier/bayes.rb +++ b/lib/classifier/bayes.rb @@ -328,12 +328,12 @@ def remove_category(category) # puts "#{progress.completed} documents processed" # end # - # @rbs (?(String | Symbol), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Streaming::Progress) -> void } -> void + # @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty? raise ArgumentError, 'Provide both category and io, or use keyword arguments' if category.nil? ^ io.nil? - (category && io ? { category => io } : categories).each do |category, io| + (category && io ? { category => io } : categories).each do |(category, io)| stream_train_category(category, io, batch_size: batch_size, &) end end diff --git a/lib/classifier/knn.rb b/lib/classifier/knn.rb index a9e29f95..1a519e5e 100644 --- a/lib/classifier/knn.rb +++ b/lib/classifier/knn.rb @@ -268,8 +268,9 @@ def self.load_checkpoint(storage:, checkpoint_id:) # puts "#{progress.completed} documents processed" # end # - # @rbs (?(String | Symbol), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Streaming::Progress) -> void } -> void + # @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block) + # @type var categories: untyped @lsi.train_from_stream(category, io, batch_size: batch_size, **categories, &block) synchronize { @dirty = true } end diff --git a/lib/classifier/logistic_regression.rb b/lib/classifier/logistic_regression.rb index 35a3543d..1ba414b7 100644 --- a/lib/classifier/logistic_regression.rb +++ b/lib/classifier/logistic_regression.rb @@ -390,12 +390,12 @@ def self.load_checkpoint(storage:, checkpoint_id:) # end # classifier.fit # - # @rbs (?(String | Symbol), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Streaming::Progress) -> void } -> void + # @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty? raise ArgumentError, 'Provide both category and io, or use keyword arguments' if category.nil? ^ io.nil? - (category && io ? { category => io } : categories).each do |category, io| + (category && io ? { category => io } : categories).each do |(category, io)| stream_train_category(category, io, batch_size:, &) end end diff --git a/lib/classifier/lsi.rb b/lib/classifier/lsi.rb index 90f8f10a..37aa8d3e 100644 --- a/lib/classifier/lsi.rb +++ b/lib/classifier/lsi.rb @@ -662,14 +662,14 @@ def self.load_checkpoint(storage:, checkpoint_id:) # puts "#{progress.completed} documents processed" # end # - # @rbs (?(String | Symbol), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Streaming::Progress) -> void } -> void + # @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty? raise ArgumentError, 'Provide both category and io, or use keyword arguments' if category.nil? ^ io.nil? original_auto_rebuild = @auto_rebuild @auto_rebuild = false - (category && io ? { category => io } : categories).each do |category, io| + (category && io ? { category => io } : categories).each do |(category, io)| stream_train_category(category, io, batch_size:, &) end ensure diff --git a/lib/classifier/streaming.rb b/lib/classifier/streaming.rb index 6b80f44b..329dabd3 100644 --- a/lib/classifier/streaming.rb +++ b/lib/classifier/streaming.rb @@ -26,7 +26,7 @@ module Streaming # Trains the classifier from an IO stream. # Each line in the stream is treated as a separate document. # - # @rbs (?(Symbol | String), ?IO, ?batch_size: Integer, **Hash[Symbol, IO]) { (Progress) -> void } -> void + # @rbs (?(Symbol | String | nil), ?IO?, ?batch_size: Integer, **Hash[Symbol, IO]) { (Progress) -> void } -> void def train_from_stream(category = nil, io = nil, batch_size: DEFAULT_BATCH_SIZE, **categories, &block) raise NotImplementedError, "#{self.class} must implement train_from_stream" end From 19bd731f68975fb4b0e346b09e3f5772c87090b1 Mon Sep 17 00:00:00 2001 From: Artem Yegorov Date: Sun, 31 May 2026 21:27:17 +0300 Subject: [PATCH 5/5] fix: some comments (#114) --- lib/classifier/bayes.rb | 11 +++--- lib/classifier/knn.rb | 6 +-- lib/classifier/logistic_regression.rb | 11 +++--- lib/classifier/lsi.rb | 26 ++++++++----- lib/classifier/streaming.rb | 2 +- test/bayes/streaming_test.rb | 2 +- test/knn/streaming_test.rb | 14 +++++-- test/logistic_regression/streaming_test.rb | 14 +++++-- test/lsi/streaming_test.rb | 44 +++++++++++++++++++++- 9 files changed, 97 insertions(+), 33 deletions(-) diff --git a/lib/classifier/bayes.rb b/lib/classifier/bayes.rb index 59272b26..622360d4 100644 --- a/lib/classifier/bayes.rb +++ b/lib/classifier/bayes.rb @@ -331,10 +331,11 @@ def remove_category(category) # @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty? - raise ArgumentError, 'Provide both category and io, or use keyword arguments' if category.nil? ^ io.nil? + raise ArgumentError, 'Provide both category and io, or use keyword arguments' if [category, io].one?(&:nil?) - (category && io ? { category => io } : categories).each do |(category, io)| - stream_train_category(category, io, batch_size: batch_size, &) + pairs = category && io ? { category => io } : categories + pairs.each do |cat, stream| + stream_train_category(cat, stream, batch_size: batch_size, &) end end @@ -386,8 +387,8 @@ def self.load_checkpoint(storage:, checkpoint_id:) # @rbs (String | Symbol, IO, batch_size: Integer) { (Streaming::Progress) -> void } -> void def stream_train_category(category, io, batch_size:) category = category.prepare_category_name - raise StandardError, "No such category: #{category}" unless @categories.key?(category) - raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) + raise ArgumentError, "No such category: #{category}" unless @categories.key?(category) + raise ArgumentError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) reader = Streaming::LineReader.new(io, batch_size: batch_size) total = reader.estimate_line_count diff --git a/lib/classifier/knn.rb b/lib/classifier/knn.rb index 1a519e5e..0c06dfa5 100644 --- a/lib/classifier/knn.rb +++ b/lib/classifier/knn.rb @@ -269,9 +269,9 @@ def self.load_checkpoint(storage:, checkpoint_id:) # end # # @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void - def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block) - # @type var categories: untyped - @lsi.train_from_stream(category, io, batch_size: batch_size, **categories, &block) + def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) + # @type var categories: untype + @lsi.train_from_stream(category, io, batch_size: batch_size, **categories, &) synchronize { @dirty = true } end diff --git a/lib/classifier/logistic_regression.rb b/lib/classifier/logistic_regression.rb index 1ba414b7..0d1819e9 100644 --- a/lib/classifier/logistic_regression.rb +++ b/lib/classifier/logistic_regression.rb @@ -393,10 +393,11 @@ def self.load_checkpoint(storage:, checkpoint_id:) # @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty? - raise ArgumentError, 'Provide both category and io, or use keyword arguments' if category.nil? ^ io.nil? + raise ArgumentError, 'Provide both category and io, or use keyword arguments' if [category, io].one?(&:nil?) - (category && io ? { category => io } : categories).each do |(category, io)| - stream_train_category(category, io, batch_size:, &) + pairs = category && io ? { category => io } : categories + pairs.each do |cat, stream| + stream_train_category(cat, stream, batch_size:, &) end end @@ -429,8 +430,8 @@ def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_ # @rbs (String | Symbol, IO, batch_size: Integer) { (Streaming::Progress) -> void } -> void def stream_train_category(category, io, batch_size:) category = category.to_s.prepare_category_name - raise StandardError, "No such category: #{category}" unless @categories.include?(category) - raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) + raise ArgumentError, "No such category: #{category}" unless @categories.include?(category) + raise ArgumentError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) reader = Streaming::LineReader.new(io, batch_size: batch_size) total = reader.estimate_line_count diff --git a/lib/classifier/lsi.rb b/lib/classifier/lsi.rb index 37aa8d3e..3476cd90 100644 --- a/lib/classifier/lsi.rb +++ b/lib/classifier/lsi.rb @@ -662,19 +662,27 @@ def self.load_checkpoint(storage:, checkpoint_id:) # puts "#{progress.completed} documents processed" # end # + # rubocop:disable Metrics/CyclomaticComplexity, Metrics/PerceivedComplexity # @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &) + # rubocop:enable Metrics/CyclomaticComplexity, Metrics/PerceivedComplexity raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty? - raise ArgumentError, 'Provide both category and io, or use keyword arguments' if category.nil? ^ io.nil? + raise ArgumentError, 'Provide both category and io, or use keyword arguments' if [category, io].one?(&:nil?) - original_auto_rebuild = @auto_rebuild - @auto_rebuild = false - (category && io ? { category => io } : categories).each do |(category, io)| - stream_train_category(category, io, batch_size:, &) + pairs = category && io ? { category => io } : categories + pairs.each_value do |io| + raise ArgumentError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) + end + begin + original_auto_rebuild = @auto_rebuild + @auto_rebuild = false + pairs.each do |cat, stream| + stream_train_category(cat, stream, batch_size:, &) + end + ensure + @auto_rebuild = original_auto_rebuild + build_index if original_auto_rebuild end - ensure - @auto_rebuild = original_auto_rebuild - build_index if original_auto_rebuild end # Adds items to the index in batches from an array. @@ -725,8 +733,6 @@ def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_ # Trains from an IO stream with a single category. # @rbs (String | Symbol, IO, batch_size: Integer) { (Streaming::Progress) -> void } -> void def stream_train_category(category, io, batch_size:) - raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line) - reader = Streaming::LineReader.new(io, batch_size: batch_size) total = reader.estimate_line_count progress = Streaming::Progress.new(total: total) diff --git a/lib/classifier/streaming.rb b/lib/classifier/streaming.rb index 329dabd3..72382670 100644 --- a/lib/classifier/streaming.rb +++ b/lib/classifier/streaming.rb @@ -26,7 +26,7 @@ module Streaming # Trains the classifier from an IO stream. # Each line in the stream is treated as a separate document. # - # @rbs (?(Symbol | String | nil), ?IO?, ?batch_size: Integer, **Hash[Symbol, IO]) { (Progress) -> void } -> void + # @rbs (?(Symbol | String | nil), ?IO?, ?batch_size: Integer, **IO) { (Progress) -> void } -> void def train_from_stream(category = nil, io = nil, batch_size: DEFAULT_BATCH_SIZE, **categories, &block) raise NotImplementedError, "#{self.class} must implement train_from_stream" end diff --git a/test/bayes/streaming_test.rb b/test/bayes/streaming_test.rb index 73771ad0..16b0a75d 100644 --- a/test/bayes/streaming_test.rb +++ b/test/bayes/streaming_test.rb @@ -29,7 +29,7 @@ def test_train_from_stream_many_categories end def test_train_from_stream_invalid_io_type - assert_raises(StandardError) do + assert_raises(ArgumentError) do @classifier.train_from_stream(spam: Object.new) end end diff --git a/test/knn/streaming_test.rb b/test/knn/streaming_test.rb index f71885fb..f3cdd4d3 100644 --- a/test/knn/streaming_test.rb +++ b/test/knn/streaming_test.rb @@ -22,8 +22,16 @@ def test_train_from_stream_many_categories def test_train_from_stream_invalid_io_type knn = Classifier::KNN.new - assert_raises(StandardError) do - knn.train_from_stream(spam: Object.new) - end + assert_raises(ArgumentError) { knn.train_from_stream(spam: Object.new) } + end + + def test_train_from_stream_raises_without_args + knn = Classifier::KNN.new + assert_raises(ArgumentError) { knn.train_from_stream } + end + + def test_train_from_stream_raises_with_partial_args + knn = Classifier::KNN.new + assert_raises(ArgumentError) { knn.train_from_stream(:spam) } end end diff --git a/test/logistic_regression/streaming_test.rb b/test/logistic_regression/streaming_test.rb index ac06568a..2fd542b3 100644 --- a/test/logistic_regression/streaming_test.rb +++ b/test/logistic_regression/streaming_test.rb @@ -24,8 +24,16 @@ def test_train_from_stream_many_categories def test_train_from_stream_invalid_io_type classifier = Classifier::LogisticRegression.new('Spam', 'Ham') - assert_raises(StandardError) do - classifier.train_from_stream(spam: Object.new) - end + assert_raises(ArgumentError) { classifier.train_from_stream(spam: Object.new) } + end + + def test_train_from_stream_raises_without_args + classifier = Classifier::LogisticRegression.new('Spam', 'Ham') + assert_raises(ArgumentError) { classifier.train_from_stream } + end + + def test_train_from_stream_raises_with_partial_args + classifier = Classifier::LogisticRegression.new('Spam', 'Ham') + assert_raises(ArgumentError) { classifier.train_from_stream(:spam) } end end diff --git a/test/lsi/streaming_test.rb b/test/lsi/streaming_test.rb index c9327dfb..787e625e 100644 --- a/test/lsi/streaming_test.rb +++ b/test/lsi/streaming_test.rb @@ -34,10 +34,26 @@ def test_train_from_stream_many_categories assert_equal :cat, lsi.classify('independent curious pet') end + def test_train_from_stream_raises_without_args + assert_raises(ArgumentError) { @lsi.train_from_stream } + end + + def test_train_from_stream_raises_with_partial_args + assert_raises(ArgumentError) { @lsi.train_from_stream(:spam) } + end + def test_train_from_stream_invalid_io_type - assert_raises(StandardError) do - @lsi.train_from_stream(category: Object.new) + assert_raises(ArgumentError) { @lsi.train_from_stream(category: Object.new) } + end + + def test_train_from_stream_with_invalid_io_type_does_not_modify_auto_rebuild_setting + @lsi = Classifier::LSI.new(auto_rebuild: true) + + assert_raises(ArgumentError) do + @lsi.train_from_stream(cat1: StringIO.new("one\ntwo\n"), cat2: Object.new) end + + assert @lsi.auto_rebuild end def test_train_from_stream_empty_io @@ -99,6 +115,18 @@ def test_train_from_stream_rebuilds_index_when_auto_rebuild refute_predicate @lsi, :needs_rebuild? end + def test_train_from_stream_with_keyword_categories_rebuilds_index_when_auto_rebuild + @lsi = Classifier::LSI.new(auto_rebuild: true) + + @lsi.train_from_stream( + dog: StringIO.new("dogs are loyal\ndogs bark\n"), + cat: StringIO.new("cats are independent\ncats meow\n") + ) + + # Index should be built + refute_predicate @lsi, :needs_rebuild? + end + def test_train_from_stream_skips_rebuild_when_auto_rebuild_false @lsi = Classifier::LSI.new(auto_rebuild: false) @@ -108,6 +136,18 @@ def test_train_from_stream_skips_rebuild_when_auto_rebuild_false assert_predicate @lsi, :needs_rebuild? end + def test_train_from_stream_with_keyword_categories_skips_rebuild_when_auto_rebuild_false + @lsi = Classifier::LSI.new(auto_rebuild: false) + + @lsi.train_from_stream( + cat1: StringIO.new("document one\ndocument two\n"), + cat2: StringIO.new("document three\ndocument four\n") + ) + + # Index should need rebuild + assert_predicate @lsi, :needs_rebuild? + end + def test_train_from_stream_with_file Tempfile.create(['corpus', '.txt']) do |file| file.puts 'dogs are loyal pets'