From haiyangz@microsoft.com Fri May 28 16:22:49 2010 From: Haiyang Zhang Date: Fri, 28 May 2010 23:22:44 +0000 Subject: Staging: hv: Fix race condition on vmbus channel initialization Cc: Hank Janssen Message-ID: <1FB5E1D5CA062146B38059374562DF7266B8D340@TK5EX14MBXC128.redmond.corp.microsoft.com> From: Haiyang Zhang Subject: [PATCH] staging: hv: Fix race condition on vmbus channel initialization There is a possible race condition when hv_utils starts to load immediately after hv_vmbus is loading - null pointer error could happen. This patch added wait/completion to ensure all channels are ready before vmbus loading completes. So another module won't have any uninitialized channel. Signed-off-by: Haiyang Zhang Signed-off-by: Hank Janssen Signed-off-by: Greg Kroah-Hartman --- drivers/staging/hv/channel_mgmt.c | 41 +++++++++++++++++++++++++++----------- drivers/staging/hv/vmbus.h | 2 + drivers/staging/hv/vmbus_drv.c | 3 ++ 3 files changed, 35 insertions(+), 11 deletions(-) --- a/drivers/staging/hv/channel_mgmt.c +++ b/drivers/staging/hv/channel_mgmt.c @@ -23,6 +23,7 @@ #include #include #include +#include #include "osd.h" #include "logging.h" #include "vmbus_private.h" @@ -293,6 +294,25 @@ void FreeVmbusChannel(struct vmbus_chann Channel); } + +DECLARE_COMPLETION(hv_channel_ready); + +/* + * Count initialized channels, and ensure all channels are ready when hv_vmbus + * module loading completes. + */ +static void count_hv_channel(void) +{ + static int counter; + unsigned long flags; + + spin_lock_irqsave(&gVmbusConnection.channel_lock, flags); + if (++counter == MAX_MSG_TYPES) + complete(&hv_channel_ready); + spin_unlock_irqrestore(&gVmbusConnection.channel_lock, flags); +} + + /* * VmbusChannelProcessOffer - Process the offer by creating a channel/device * associated with this offer @@ -373,22 +393,21 @@ static void VmbusChannelProcessOffer(voi * can cleanup properly */ newChannel->State = CHANNEL_OPEN_STATE; - cnt = 0; - while (cnt != MAX_MSG_TYPES) { + /* Open IC channels */ + for (cnt = 0; cnt < MAX_MSG_TYPES; cnt++) { if (memcmp(&newChannel->OfferMsg.Offer.InterfaceType, &hv_cb_utils[cnt].data, - sizeof(struct hv_guid)) == 0) { + sizeof(struct hv_guid)) == 0 && + VmbusChannelOpen(newChannel, 2 * PAGE_SIZE, + 2 * PAGE_SIZE, NULL, 0, + hv_cb_utils[cnt].callback, + newChannel) == 0) { + hv_cb_utils[cnt].channel = newChannel; DPRINT_INFO(VMBUS, "%s", - hv_cb_utils[cnt].log_msg); - - if (VmbusChannelOpen(newChannel, 2 * PAGE_SIZE, - 2 * PAGE_SIZE, NULL, 0, - hv_cb_utils[cnt].callback, - newChannel) == 0) - hv_cb_utils[cnt].channel = newChannel; + hv_cb_utils[cnt].log_msg); + count_hv_channel(); } - cnt++; } } DPRINT_EXIT(VMBUS); --- a/drivers/staging/hv/vmbus.h +++ b/drivers/staging/hv/vmbus.h @@ -74,4 +74,6 @@ int vmbus_child_driver_register(struct d void vmbus_child_driver_unregister(struct driver_context *driver_ctx); void vmbus_get_interface(struct vmbus_channel_interface *interface); +extern struct completion hv_channel_ready; + #endif /* _VMBUS_H_ */ --- a/drivers/staging/hv/vmbus_drv.c +++ b/drivers/staging/hv/vmbus_drv.c @@ -27,6 +27,7 @@ #include #include #include +#include #include "version_info.h" #include "osd.h" #include "logging.h" @@ -356,6 +357,8 @@ static int vmbus_bus_init(int (*drv_init vmbus_drv_obj->GetChannelOffers(); + wait_for_completion(&hv_channel_ready); + cleanup: DPRINT_EXIT(VMBUS_DRV);