diff --git a/lib/mongo/collection.rb b/lib/mongo/collection.rb index 9f9145d..78d426a 100644 --- a/lib/mongo/collection.rb +++ b/lib/mongo/collection.rb @@ -202,13 +202,7 @@ module Mongo def insert(doc_or_docs, options={}) doc_or_docs = [doc_or_docs] unless doc_or_docs.is_a?(Array) doc_or_docs.collect! { |doc| @pk_factory.create_pk(doc) } - result = insert_documents(doc_or_docs) - if options.delete(:safe) - error = @db.error - if error - raise OperationFailure, error - end - end + result = insert_documents(doc_or_docs, @name, true, options[:safe]) result.size > 1 ? result : result.first end alias_method :<<, :insert @@ -259,11 +253,12 @@ module Mongo message.put_int(options[:upsert] ? 1 : 0) # 1 if a repsert operation (upsert) message.put_array(BSON.new.serialize(spec, false).to_a) message.put_array(BSON.new.serialize(document, false).to_a) - @db.send_message_with_operation(Mongo::Constants::OP_UPDATE, message, - "db.#{@name}.update(#{spec.inspect}, #{document.inspect})") - - if options[:safe] && error=@db.error - raise OperationFailure, error + if options[:safe] + @db.send_message_with_safe_check(Mongo::Constants::OP_UPDATE, message, + "db.#{@name}.update(#{spec.inspect}, #{document.inspect})") + else + @db.send_message_with_operation(Mongo::Constants::OP_UPDATE, message, + "db.#{@name}.update(#{spec.inspect}, #{document.inspect})") end end @@ -480,13 +475,18 @@ EOS # Sends an Mongo::Constants::OP_INSERT message to the database. # Takes an array of +documents+, an optional +collection_name+, and a # +check_keys+ setting. - def insert_documents(documents, collection_name=@name, check_keys=true) + def insert_documents(documents, collection_name=@name, check_keys=true, safe=false) message = ByteBuffer.new message.put_int(0) BSON.serialize_cstr(message, "#{@db.name}.#{collection_name}") documents.each { |doc| message.put_array(BSON.new.serialize(doc, check_keys).to_a) } - @db.send_message_with_operation(Mongo::Constants::OP_INSERT, message, - "db.#{collection_name}.insert(#{documents.inspect})") + if safe + @db.send_message_with_safe_check(Mongo::Constants::OP_INSERT, message, + "db.#{collection_name}.insert(#{documents.inspect})") + else + @db.send_message_with_operation(Mongo::Constants::OP_INSERT, message, + "db.#{collection_name}.insert(#{documents.inspect})") + end documents.collect { |o| o[:_id] || o['_id'] } end diff --git a/lib/mongo/db.rb b/lib/mongo/db.rb index 5e9e5a0..0540984 100644 --- a/lib/mongo/db.rb +++ b/lib/mongo/db.rb @@ -450,6 +450,21 @@ module Mongo end end + # Sends a message to the database, waits for a response, and raises + # and exception if the operation has failed. + def send_message_with_safe_check(operation, message, log_message=nil) + message_with_headers = add_message_headers(operation, message) + message_with_check = last_error_message + @logger.debug(" MONGODB #{log_message || message}") if @logger + @semaphore.synchronize do + send_message_on_socket(message_with_headers.append!(message_with_check).to_s) + docs, num_received, cursor_id = receive + if num_received == 1 && error = docs[0]['err'] + raise Mongo::OperationFailure, error + end + end + end + # Note: this method is a stub. Will be completed in an upcoming refactoring. def receive_message_with_operation(operation, message, log_message=nil) message_with_headers = add_message_headers(operation, message).to_s @@ -512,10 +527,10 @@ module Mongo end # Sending a message on socket. - def send_message_on_socket(message_with_headers) + def send_message_on_socket(packed_message) connect_to_master if !connected? && @auto_reconnect begin - @socket.print(message_with_headers) + @socket.print(packed_message) @socket.flush rescue => ex close @@ -583,10 +598,6 @@ module Mongo end end - def _synchronize &block - @semaphore.synchronize &block - end - def full_collection_name(collection_name) "#{@name}.#{collection_name}" end @@ -619,6 +630,35 @@ module Mongo @@current_request_id end + # Creates a getlasterror message. + def last_error_message + generate_last_error_message + end + + def generate_last_error_message + message = ByteBuffer.new + message.put_int(0) + BSON.serialize_cstr(message, "#{@name}.$cmd") + message.put_int(0) + message.put_int(-1) + message.put_array(BSON.new.serialize({:getlasterror => 1}).to_a) + add_message_headers(Mongo::Constants::OP_QUERY, message) + end + + def reset_error_message + @@reset_error_message ||= generate_reset_error_message + end + + def generate_reset_error_message + message = ByteBuffer.new + message.put_int(0) + BSON.serialize_cstr(message, "#{@name}.$cmd") + message.put_int(0) + message.put_int(-1) + message.put_array(BSON.new.serialize({:reseterror => 1}).to_a) + add_message_headers(Mongo::Constants::OP_QUERY, message) + end + def hash_password(username, plaintext) Digest::MD5.hexdigest("#{username}:mongo:#{plaintext}") end diff --git a/test/test_threading.rb b/test/test_threading.rb index 8d11b11..fb2a78c 100644 --- a/test/test_threading.rb +++ b/test/test_threading.rb @@ -1,18 +1,68 @@ -$LOAD_PATH[0,0] = File.join(File.dirname(__FILE__), '..', 'lib') -require 'mongo' -require 'test/unit' +require 'test/test_helper' class TestThreading < Test::Unit::TestCase include Mongo - @@host = ENV['MONGO_RUBY_DRIVER_HOST'] || 'localhost' - @@port = ENV['MONGO_RUBY_DRIVER_PORT'] || Connection::DEFAULT_PORT - @@db = Connection.new(@@host, @@port).db('ruby-mongo-test') + @@db = Connection.new.db('ruby-mongo-test') @@coll = @@db.collection('thread-test-collection') + def set_up_safe_data + @@db.drop_collection('duplicate') + @@db.drop_collection('unique') + @duplicate = @@db.collection('duplicate') + @unique = @@db.collection('unique') + + @duplicate.insert("test" => "insert") + @duplicate.insert("test" => "update") + @unique.insert("test" => "insert") + @unique.insert("test" => "update") + @unique.create_index("test", true) + end + + def test_safe_update + set_up_safe_data + threads = [] + 100.times do |i| + threads[i] = Thread.new do + if i % 2 == 0 + assert_raise Mongo::OperationFailure do + @unique.update({"test" => "insert"}, {"$set" => {"test" => "update"}}, :safe => true) + end + else + @duplicate.update({"test" => "insert"}, {"$set" => {"test" => "update"}}, :safe => true) + end + end + end + + 100.times do |i| + threads[i].join + end + end + + def test_safe_insert + set_up_safe_data + threads = [] + 100.times do |i| + threads[i] = Thread.new do + if i % 2 == 0 + assert_raise Mongo::OperationFailure do + @unique.insert({"test" => "insert"}, :safe => true) + end + else + @duplicate.insert({"test" => "insert"}, :safe => true) + end + end + end + + 100.times do |i| + threads[i].join + end + end + def test_threading - @@coll.remove + @@coll.drop + @@coll = @@db.collection('thread-test-collection') 1000.times do |i| @@coll.insert("x" => i) @@ -21,13 +71,13 @@ class TestThreading < Test::Unit::TestCase threads = [] 10.times do |i| - threads[i] = Thread.new{ + threads[i] = Thread.new do sum = 0 - @@coll.find().each { |document| + @@coll.find().each do |document| sum += document["x"] - } + end assert_equal 499500, sum - } + end end 10.times do |i| diff --git a/test/unit/collection_test.rb b/test/unit/collection_test.rb index cb1b792..da7506d 100644 --- a/test/unit/collection_test.rb +++ b/test/unit/collection_test.rb @@ -30,6 +30,25 @@ class CollectionTest < Test::Unit::TestCase end @coll.insert({:title => 'Moby Dick'}) end + + should "send safe update message" do + @db = MockDB.new("testing", ['localhost', 27017], :logger => @logger) + @coll = @db.collection('books') + @db.expects(:send_message_with_safe_check).with do |op, msg, log| + op == 2001 && log.include?("db.books.update") + end + @coll.update({}, {:title => 'Moby Dick'}, :safe => true) + + end + + should "send safe insert message" do + @db = MockDB.new("testing", ['localhost', 27017], :logger => @logger) + @coll = @db.collection('books') + @db.expects(:send_message_with_safe_check).with do |op, msg, log| + op == 2001 && log.include?("db.books.update") + end + @coll.update({}, {:title => 'Moby Dick'}, :safe => true) + end end end diff --git a/test/unit/db_test.rb b/test/unit/db_test.rb index 9e70da3..2622cbf 100644 --- a/test/unit/db_test.rb +++ b/test/unit/db_test.rb @@ -3,10 +3,22 @@ require 'test/test_helper' class DBTest < Test::Unit::TestCase class MockDB < DB + attr_accessor :socket def connect_to_master true end + + public :add_message_headers + end + + def insert_message(db, documents) + documents = [documents] unless documents.is_a?(Array) + message = ByteBuffer.new + message.put_int(0) + BSON.serialize_cstr(message, "#{db.name}.test") + documents.each { |doc| message.put_array(BSON.new.serialize(doc, true).to_a) } + message = db.add_message_headers(Mongo::Constants::OP_INSERT, message) end context "DB commands" do @@ -47,6 +59,24 @@ class DBTest < Test::Unit::TestCase end end + context "safe messages" do + setup do + @db = MockDB.new("testing", ['localhost', 27017]) + @collection = mock() + @db.stubs(:system_command_collection).returns(@collection) + end + + should "receive getlasterror message" do + @socket = mock() + @socket.stubs(:close) + @socket.expects(:flush) + @socket.expects(:print).with { |message| message.include?('getlasterror') } + @db.socket = @socket + @db.stubs(:receive) + message = insert_message(@db, {:a => 1}) + @db.send_message_with_safe_check(Mongo::Constants::OP_QUERY, message) + end + end end