From fec8755b6193c93a935423fdd6026b354aa2e15e Mon Sep 17 00:00:00 2001 From: Hank Janssen Date: Mon, 13 Jul 2009 15:34:54 -0700 Subject: Staging: hv: add the Hyper-V virtual network driver From: Hank Janssen This is the virtual network driver when running Linux on top of Hyper-V. Signed-off-by: Hank Janssen Signed-off-by: Haiyang Zhang Signed-off-by: Greg Kroah-Hartman --- drivers/staging/hv/NetVsc.c | 1499 +++++++++++++++++++++++++++++++++++++++ drivers/staging/hv/NetVsc.h | 91 ++ drivers/staging/hv/RndisFilter.c | 1162 ++++++++++++++++++++++++++++++ drivers/staging/hv/RndisFilter.h | 61 + drivers/staging/hv/netvsc_drv.c | 720 ++++++++++++++++++ 5 files changed, 3533 insertions(+) create mode 100644 drivers/staging/hv/netvsc.c --- /dev/null +++ b/drivers/staging/hv/NetVsc.c @@ -0,0 +1,1499 @@ +/* + * + * Copyright (c) 2009, Microsoft Corporation. + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along with + * this program; if not, write to the Free Software Foundation, Inc., 59 Temple + * Place - Suite 330, Boston, MA 02111-1307 USA. + * + * Authors: + * Hank Janssen + * + */ + + +#include "logging.h" +#include "NetVsc.h" +#include "RndisFilter.h" + + +// +// Globals +// +static const char* gDriverName="netvsc"; + +// {F8615163-DF3E-46c5-913F-F2D2F965ED0E} +static const GUID gNetVscDeviceType={ + .Data = {0x63, 0x51, 0x61, 0xF8, 0x3E, 0xDF, 0xc5, 0x46, 0x91, 0x3F, 0xF2, 0xD2, 0xF9, 0x65, 0xED, 0x0E} +}; + + +// +// Internal routines +// +static int +NetVscOnDeviceAdd( + DEVICE_OBJECT *Device, + void *AdditionalInfo + ); + +static int +NetVscOnDeviceRemove( + DEVICE_OBJECT *Device + ); + +static void +NetVscOnCleanup( + DRIVER_OBJECT *Driver + ); + +static void +NetVscOnChannelCallback( + PVOID context + ); + +static int +NetVscInitializeSendBufferWithNetVsp( + DEVICE_OBJECT *Device + ); + +static int +NetVscInitializeReceiveBufferWithNetVsp( + DEVICE_OBJECT *Device + ); + +static int +NetVscDestroySendBuffer( + NETVSC_DEVICE *NetDevice + ); + +static int +NetVscDestroyReceiveBuffer( + NETVSC_DEVICE *NetDevice + ); + +static int +NetVscConnectToVsp( + DEVICE_OBJECT *Device + ); + +static void +NetVscOnSendCompletion( + DEVICE_OBJECT *Device, + VMPACKET_DESCRIPTOR *Packet + ); + +static int +NetVscOnSend( + DEVICE_OBJECT *Device, + NETVSC_PACKET *Packet + ); + +static void +NetVscOnReceive( + DEVICE_OBJECT *Device, + VMPACKET_DESCRIPTOR *Packet + ); + +static void +NetVscOnReceiveCompletion( + PVOID Context + ); + +static void +NetVscSendReceiveCompletion( + DEVICE_OBJECT *Device, + UINT64 TransactionId + ); + +static inline NETVSC_DEVICE* AllocNetDevice(DEVICE_OBJECT *Device) +{ + NETVSC_DEVICE *netDevice; + + netDevice = MemAllocZeroed(sizeof(NETVSC_DEVICE)); + if (!netDevice) + return NULL; + + // Set to 2 to allow both inbound and outbound traffic + InterlockedCompareExchange(&netDevice->RefCount, 2, 0); + + netDevice->Device = Device; + Device->Extension = netDevice; + + return netDevice; +} + +static inline void FreeNetDevice(NETVSC_DEVICE *Device) +{ + ASSERT(Device->RefCount == 0); + Device->Device->Extension = NULL; + MemFree(Device); +} + + +// Get the net device object iff exists and its refcount > 1 +static inline NETVSC_DEVICE* GetOutboundNetDevice(DEVICE_OBJECT *Device) +{ + NETVSC_DEVICE *netDevice; + + netDevice = (NETVSC_DEVICE*)Device->Extension; + if (netDevice && netDevice->RefCount > 1) + { + InterlockedIncrement(&netDevice->RefCount); + } + else + { + netDevice = NULL; + } + + return netDevice; +} + +// Get the net device object iff exists and its refcount > 0 +static inline NETVSC_DEVICE* GetInboundNetDevice(DEVICE_OBJECT *Device) +{ + NETVSC_DEVICE *netDevice; + + netDevice = (NETVSC_DEVICE*)Device->Extension; + if (netDevice && netDevice->RefCount) + { + InterlockedIncrement(&netDevice->RefCount); + } + else + { + netDevice = NULL; + } + + return netDevice; +} + +static inline void PutNetDevice(DEVICE_OBJECT *Device) +{ + NETVSC_DEVICE *netDevice; + + netDevice = (NETVSC_DEVICE*)Device->Extension; + ASSERT(netDevice); + + InterlockedDecrement(&netDevice->RefCount); +} + +static inline NETVSC_DEVICE* ReleaseOutboundNetDevice(DEVICE_OBJECT *Device) +{ + NETVSC_DEVICE *netDevice; + + netDevice = (NETVSC_DEVICE*)Device->Extension; + if (netDevice == NULL) + return NULL; + + // Busy wait until the ref drop to 2, then set it to 1 + while (InterlockedCompareExchange(&netDevice->RefCount, 1, 2) != 2) + { + Sleep(100); + } + + return netDevice; +} + +static inline NETVSC_DEVICE* ReleaseInboundNetDevice(DEVICE_OBJECT *Device) +{ + NETVSC_DEVICE *netDevice; + + netDevice = (NETVSC_DEVICE*)Device->Extension; + if (netDevice == NULL) + return NULL; + + // Busy wait until the ref drop to 1, then set it to 0 + while (InterlockedCompareExchange(&netDevice->RefCount, 0, 1) != 1) + { + Sleep(100); + } + + Device->Extension = NULL; + return netDevice; +} + +/*++; + + +Name: + NetVscInitialize() + +Description: + Main entry point + +--*/ +int +NetVscInitialize( + DRIVER_OBJECT *drv + ) +{ + NETVSC_DRIVER_OBJECT* driver = (NETVSC_DRIVER_OBJECT*)drv; + int ret=0; + + DPRINT_ENTER(NETVSC); + + DPRINT_DBG(NETVSC, "sizeof(NETVSC_PACKET)=%d, sizeof(NVSP_MESSAGE)=%d, sizeof(VMTRANSFER_PAGE_PACKET_HEADER)=%d", + sizeof(NETVSC_PACKET), sizeof(NVSP_MESSAGE), sizeof(VMTRANSFER_PAGE_PACKET_HEADER)); + + // Make sure we are at least 2 pages since 1 page is used for control + ASSERT(driver->RingBufferSize >= (PAGE_SIZE << 1)); + + drv->name = gDriverName; + memcpy(&drv->deviceType, &gNetVscDeviceType, sizeof(GUID)); + + // Make sure it is set by the caller + ASSERT(driver->OnReceiveCallback); + ASSERT(driver->OnLinkStatusChanged); + + // Setup the dispatch table + driver->Base.OnDeviceAdd = NetVscOnDeviceAdd; + driver->Base.OnDeviceRemove = NetVscOnDeviceRemove; + driver->Base.OnCleanup = NetVscOnCleanup; + + driver->OnSend = NetVscOnSend; + + RndisFilterInit(driver); + + DPRINT_EXIT(NETVSC); + + return ret; +} + +static int +NetVscInitializeReceiveBufferWithNetVsp( + DEVICE_OBJECT *Device + ) +{ + int ret=0; + NETVSC_DEVICE *netDevice; + NVSP_MESSAGE *initPacket; + + DPRINT_ENTER(NETVSC); + + netDevice = GetOutboundNetDevice(Device); + if (!netDevice) + { + DPRINT_ERR(NETVSC, "unable to get net device...device being destroyed?"); + DPRINT_EXIT(NETVSC); + return -1; + } + ASSERT(netDevice->ReceiveBufferSize > 0); + ASSERT((netDevice->ReceiveBufferSize & (PAGE_SIZE-1)) == 0); // page-size grandularity + + netDevice->ReceiveBuffer = PageAlloc(netDevice->ReceiveBufferSize >> PAGE_SHIFT); + if (!netDevice->ReceiveBuffer) + { + DPRINT_ERR(NETVSC, "unable to allocate receive buffer of size %d", netDevice->ReceiveBufferSize); + ret = -1; + goto Cleanup; + } + ASSERT(((ULONG_PTR)netDevice->ReceiveBuffer & (PAGE_SIZE-1)) == 0); // page-aligned buffer + + DPRINT_INFO(NETVSC, "Establishing receive buffer's GPADL..."); + + // Establish the gpadl handle for this buffer on this channel. + // Note: This call uses the vmbus connection rather than the channel to establish + // the gpadl handle. + ret = Device->Driver->VmbusChannelInterface.EstablishGpadl(Device, + netDevice->ReceiveBuffer, + netDevice->ReceiveBufferSize, + &netDevice->ReceiveBufferGpadlHandle); + + if (ret != 0) + { + DPRINT_ERR(NETVSC, "unable to establish receive buffer's gpadl"); + goto Cleanup; + } + + //WaitEventWait(ext->ChannelInitEvent); + + // Notify the NetVsp of the gpadl handle + DPRINT_INFO(NETVSC, "Sending NvspMessage1TypeSendReceiveBuffer..."); + + initPacket = &netDevice->ChannelInitPacket; + + memset(initPacket, 0, sizeof(NVSP_MESSAGE)); + + initPacket->Header.MessageType = NvspMessage1TypeSendReceiveBuffer; + initPacket->Messages.Version1Messages.SendReceiveBuffer.GpadlHandle = netDevice->ReceiveBufferGpadlHandle; + initPacket->Messages.Version1Messages.SendReceiveBuffer.Id = NETVSC_RECEIVE_BUFFER_ID; + + // Send the gpadl notification request + ret = Device->Driver->VmbusChannelInterface.SendPacket(Device, + initPacket, + sizeof(NVSP_MESSAGE), + (ULONG_PTR)initPacket, + VmbusPacketTypeDataInBand, + VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED); + if (ret != 0) + { + DPRINT_ERR(NETVSC, "unable to send receive buffer's gpadl to netvsp"); + goto Cleanup; + } + + WaitEventWait(netDevice->ChannelInitEvent); + + // Check the response + if (initPacket->Messages.Version1Messages.SendReceiveBufferComplete.Status != NvspStatusSuccess) + { + DPRINT_ERR(NETVSC, + "Unable to complete receive buffer initialzation with NetVsp - status %d", + initPacket->Messages.Version1Messages.SendReceiveBufferComplete.Status); + ret = -1; + goto Cleanup; + } + + // Parse the response + ASSERT(netDevice->ReceiveSectionCount == 0); + ASSERT(netDevice->ReceiveSections == NULL); + + netDevice->ReceiveSectionCount = initPacket->Messages.Version1Messages.SendReceiveBufferComplete.NumSections; + + netDevice->ReceiveSections = MemAlloc(netDevice->ReceiveSectionCount * sizeof(NVSP_1_RECEIVE_BUFFER_SECTION)); + if (netDevice->ReceiveSections == NULL) + { + ret = -1; + goto Cleanup; + } + + memcpy(netDevice->ReceiveSections, + initPacket->Messages.Version1Messages.SendReceiveBufferComplete.Sections, + netDevice->ReceiveSectionCount * sizeof(NVSP_1_RECEIVE_BUFFER_SECTION)); + + DPRINT_INFO(NETVSC, + "Receive sections info (count %d, offset %d, endoffset %d, suballoc size %d, num suballocs %d)", + netDevice->ReceiveSectionCount, netDevice->ReceiveSections[0].Offset, netDevice->ReceiveSections[0].EndOffset, + netDevice->ReceiveSections[0].SubAllocationSize, netDevice->ReceiveSections[0].NumSubAllocations); + + + //For 1st release, there should only be 1 section that represents the entire receive buffer + if (netDevice->ReceiveSectionCount != 1 || + netDevice->ReceiveSections->Offset != 0 ) + { + ret = -1; + goto Cleanup; + } + + goto Exit; + +Cleanup: + NetVscDestroyReceiveBuffer(netDevice); + +Exit: + PutNetDevice(Device); + DPRINT_EXIT(NETVSC); + return ret; +} + + +static int +NetVscInitializeSendBufferWithNetVsp( + DEVICE_OBJECT *Device + ) +{ + int ret=0; + NETVSC_DEVICE *netDevice; + NVSP_MESSAGE *initPacket; + + DPRINT_ENTER(NETVSC); + + netDevice = GetOutboundNetDevice(Device); + if (!netDevice) + { + DPRINT_ERR(NETVSC, "unable to get net device...device being destroyed?"); + DPRINT_EXIT(NETVSC); + return -1; + } + ASSERT(netDevice->SendBufferSize > 0); + ASSERT((netDevice->SendBufferSize & (PAGE_SIZE-1)) == 0); // page-size grandularity + + netDevice->SendBuffer = PageAlloc(netDevice->SendBufferSize >> PAGE_SHIFT); + if (!netDevice->SendBuffer) + { + DPRINT_ERR(NETVSC, "unable to allocate send buffer of size %d", netDevice->SendBufferSize); + ret = -1; + goto Cleanup; + } + ASSERT(((ULONG_PTR)netDevice->SendBuffer & (PAGE_SIZE-1)) == 0); // page-aligned buffer + + DPRINT_INFO(NETVSC, "Establishing send buffer's GPADL..."); + + // Establish the gpadl handle for this buffer on this channel. + // Note: This call uses the vmbus connection rather than the channel to establish + // the gpadl handle. + ret = Device->Driver->VmbusChannelInterface.EstablishGpadl(Device, + netDevice->SendBuffer, + netDevice->SendBufferSize, + &netDevice->SendBufferGpadlHandle); + + if (ret != 0) + { + DPRINT_ERR(NETVSC, "unable to establish send buffer's gpadl"); + goto Cleanup; + } + + //WaitEventWait(ext->ChannelInitEvent); + + // Notify the NetVsp of the gpadl handle + DPRINT_INFO(NETVSC, "Sending NvspMessage1TypeSendSendBuffer..."); + + initPacket = &netDevice->ChannelInitPacket; + + memset(initPacket, 0, sizeof(NVSP_MESSAGE)); + + initPacket->Header.MessageType = NvspMessage1TypeSendSendBuffer; + initPacket->Messages.Version1Messages.SendReceiveBuffer.GpadlHandle = netDevice->SendBufferGpadlHandle; + initPacket->Messages.Version1Messages.SendReceiveBuffer.Id = NETVSC_SEND_BUFFER_ID; + + // Send the gpadl notification request + ret = Device->Driver->VmbusChannelInterface.SendPacket(Device, + initPacket, + sizeof(NVSP_MESSAGE), + (ULONG_PTR)initPacket, + VmbusPacketTypeDataInBand, + VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED); + if (ret != 0) + { + DPRINT_ERR(NETVSC, "unable to send receive buffer's gpadl to netvsp"); + goto Cleanup; + } + + WaitEventWait(netDevice->ChannelInitEvent); + + // Check the response + if (initPacket->Messages.Version1Messages.SendSendBufferComplete.Status != NvspStatusSuccess) + { + DPRINT_ERR(NETVSC, + "Unable to complete send buffer initialzation with NetVsp - status %d", + initPacket->Messages.Version1Messages.SendSendBufferComplete.Status); + ret = -1; + goto Cleanup; + } + + netDevice->SendSectionSize = initPacket->Messages.Version1Messages.SendSendBufferComplete.SectionSize; + + goto Exit; + +Cleanup: + NetVscDestroySendBuffer(netDevice); + +Exit: + PutNetDevice(Device); + DPRINT_EXIT(NETVSC); + return ret; +} + +static int +NetVscDestroyReceiveBuffer( + NETVSC_DEVICE *NetDevice + ) +{ + NVSP_MESSAGE *revokePacket; + int ret=0; + + + DPRINT_ENTER(NETVSC); + + // If we got a section count, it means we received a SendReceiveBufferComplete msg + // (ie sent NvspMessage1TypeSendReceiveBuffer msg) therefore, we need to send a revoke msg here + if (NetDevice->ReceiveSectionCount) + { + DPRINT_INFO(NETVSC, "Sending NvspMessage1TypeRevokeReceiveBuffer..."); + + // Send the revoke receive buffer + revokePacket = &NetDevice->RevokePacket; + memset(revokePacket, 0, sizeof(NVSP_MESSAGE)); + + revokePacket->Header.MessageType = NvspMessage1TypeRevokeReceiveBuffer; + revokePacket->Messages.Version1Messages.RevokeReceiveBuffer.Id = NETVSC_RECEIVE_BUFFER_ID; + + ret = NetDevice->Device->Driver->VmbusChannelInterface.SendPacket(NetDevice->Device, + revokePacket, + sizeof(NVSP_MESSAGE), + (ULONG_PTR)revokePacket, + VmbusPacketTypeDataInBand, + 0); + // If we failed here, we might as well return and have a leak rather than continue and a bugchk + if (ret != 0) + { + DPRINT_ERR(NETVSC, "unable to send revoke receive buffer to netvsp"); + DPRINT_EXIT(NETVSC); + return -1; + } + } + + // Teardown the gpadl on the vsp end + if (NetDevice->ReceiveBufferGpadlHandle) + { + DPRINT_INFO(NETVSC, "Tearing down receive buffer's GPADL..."); + + ret = NetDevice->Device->Driver->VmbusChannelInterface.TeardownGpadl(NetDevice->Device, + NetDevice->ReceiveBufferGpadlHandle); + + // If we failed here, we might as well return and have a leak rather than continue and a bugchk + if (ret != 0) + { + DPRINT_ERR(NETVSC, "unable to teardown receive buffer's gpadl"); + DPRINT_EXIT(NETVSC); + return -1; + } + NetDevice->ReceiveBufferGpadlHandle = 0; + } + + if (NetDevice->ReceiveBuffer) + { + DPRINT_INFO(NETVSC, "Freeing up receive buffer..."); + + // Free up the receive buffer + PageFree(NetDevice->ReceiveBuffer, NetDevice->ReceiveBufferSize >> PAGE_SHIFT); + NetDevice->ReceiveBuffer = NULL; + } + + if (NetDevice->ReceiveSections) + { + MemFree(NetDevice->ReceiveSections); + NetDevice->ReceiveSections = NULL; + NetDevice->ReceiveSectionCount = 0; + } + + DPRINT_EXIT(NETVSC); + + return ret; +} + + + + +static int +NetVscDestroySendBuffer( + NETVSC_DEVICE *NetDevice + ) +{ + NVSP_MESSAGE *revokePacket; + int ret=0; + + + DPRINT_ENTER(NETVSC); + + // If we got a section count, it means we received a SendReceiveBufferComplete msg + // (ie sent NvspMessage1TypeSendReceiveBuffer msg) therefore, we need to send a revoke msg here + if (NetDevice->SendSectionSize) + { + DPRINT_INFO(NETVSC, "Sending NvspMessage1TypeRevokeSendBuffer..."); + + // Send the revoke send buffer + revokePacket = &NetDevice->RevokePacket; + memset(revokePacket, 0, sizeof(NVSP_MESSAGE)); + + revokePacket->Header.MessageType = NvspMessage1TypeRevokeSendBuffer; + revokePacket->Messages.Version1Messages.RevokeSendBuffer.Id = NETVSC_SEND_BUFFER_ID; + + ret = NetDevice->Device->Driver->VmbusChannelInterface.SendPacket(NetDevice->Device, + revokePacket, + sizeof(NVSP_MESSAGE), + (ULONG_PTR)revokePacket, + VmbusPacketTypeDataInBand, + 0); + // If we failed here, we might as well return and have a leak rather than continue and a bugchk + if (ret != 0) + { + DPRINT_ERR(NETVSC, "unable to send revoke send buffer to netvsp"); + DPRINT_EXIT(NETVSC); + return -1; + } + } + + // Teardown the gpadl on the vsp end + if (NetDevice->SendBufferGpadlHandle) + { + DPRINT_INFO(NETVSC, "Tearing down send buffer's GPADL..."); + + ret = NetDevice->Device->Driver->VmbusChannelInterface.TeardownGpadl(NetDevice->Device, + NetDevice->SendBufferGpadlHandle); + + // If we failed here, we might as well return and have a leak rather than continue and a bugchk + if (ret != 0) + { + DPRINT_ERR(NETVSC, "unable to teardown send buffer's gpadl"); + DPRINT_EXIT(NETVSC); + return -1; + } + NetDevice->SendBufferGpadlHandle = 0; + } + + if (NetDevice->SendBuffer) + { + DPRINT_INFO(NETVSC, "Freeing up send buffer..."); + + // Free up the receive buffer + PageFree(NetDevice->SendBuffer, NetDevice->SendBufferSize >> PAGE_SHIFT); + NetDevice->SendBuffer = NULL; + } + + DPRINT_EXIT(NETVSC); + + return ret; +} + + + +static int +NetVscConnectToVsp( + DEVICE_OBJECT *Device + ) +{ + int ret=0; + NETVSC_DEVICE *netDevice; + NVSP_MESSAGE *initPacket; + int ndisVersion; + + DPRINT_ENTER(NETVSC); + + netDevice = GetOutboundNetDevice(Device); + if (!netDevice) + { + DPRINT_ERR(NETVSC, "unable to get net device...device being destroyed?"); + DPRINT_EXIT(NETVSC); + return -1; + } + + initPacket = &netDevice->ChannelInitPacket; + + memset(initPacket, 0, sizeof(NVSP_MESSAGE)); + initPacket->Header.MessageType = NvspMessageTypeInit; + initPacket->Messages.InitMessages.Init.MinProtocolVersion = NVSP_MIN_PROTOCOL_VERSION; + initPacket->Messages.InitMessages.Init.MaxProtocolVersion = NVSP_MAX_PROTOCOL_VERSION; + + DPRINT_INFO(NETVSC, "Sending NvspMessageTypeInit..."); + + // Send the init request + ret = Device->Driver->VmbusChannelInterface.SendPacket(Device, + initPacket, + sizeof(NVSP_MESSAGE), + (ULONG_PTR)initPacket, + VmbusPacketTypeDataInBand, + VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED); + + if( ret != 0) + { + DPRINT_ERR(NETVSC, "unable to send NvspMessageTypeInit"); + goto Cleanup; + } + + WaitEventWait(netDevice->ChannelInitEvent); + + // Now, check the response + //ASSERT(initPacket->Messages.InitMessages.InitComplete.MaximumMdlChainLength <= MAX_MULTIPAGE_BUFFER_COUNT); + DPRINT_INFO(NETVSC, "NvspMessageTypeInit status(%d) max mdl chain (%d)", + initPacket->Messages.InitMessages.InitComplete.Status, + initPacket->Messages.InitMessages.InitComplete.MaximumMdlChainLength); + + if (initPacket->Messages.InitMessages.InitComplete.Status != NvspStatusSuccess) + { + DPRINT_ERR(NETVSC, "unable to initialize with netvsp (status 0x%x)", initPacket->Messages.InitMessages.InitComplete.Status); + ret = -1; + goto Cleanup; + } + + if (initPacket->Messages.InitMessages.InitComplete.NegotiatedProtocolVersion != NVSP_PROTOCOL_VERSION_1) + { + DPRINT_ERR(NETVSC, "unable to initialize with netvsp (version expected 1 got %d)", + initPacket->Messages.InitMessages.InitComplete.NegotiatedProtocolVersion); + ret = -1; + goto Cleanup; + } + DPRINT_INFO(NETVSC, "Sending NvspMessage1TypeSendNdisVersion..."); + + // Send the ndis version + memset(initPacket, 0, sizeof(NVSP_MESSAGE)); + + ndisVersion = 0x00050000; + + initPacket->Header.MessageType = NvspMessage1TypeSendNdisVersion; + initPacket->Messages.Version1Messages.SendNdisVersion.NdisMajorVersion = (ndisVersion & 0xFFFF0000) >> 16; + initPacket->Messages.Version1Messages.SendNdisVersion.NdisMinorVersion = ndisVersion & 0xFFFF; + + // Send the init request + ret = Device->Driver->VmbusChannelInterface.SendPacket(Device, + initPacket, + sizeof(NVSP_MESSAGE), + (ULONG_PTR)initPacket, + VmbusPacketTypeDataInBand, + 0); + if (ret != 0) + { + DPRINT_ERR(NETVSC, "unable to send NvspMessage1TypeSendNdisVersion"); + ret = -1; + goto Cleanup; + } + // + // BUGBUG - We have to wait for the above msg since the netvsp uses KMCL which acknowledges packet (completion packet) + // since our Vmbus always set the VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED flag + //WaitEventWait(NetVscChannel->ChannelInitEvent); + + // Post the big receive buffer to NetVSP + ret = NetVscInitializeReceiveBufferWithNetVsp(Device); + if (ret == 0) + { + ret = NetVscInitializeSendBufferWithNetVsp(Device); + } + +Cleanup: + PutNetDevice(Device); + DPRINT_EXIT(NETVSC); + return ret; +} + +static void +NetVscDisconnectFromVsp( + NETVSC_DEVICE *NetDevice + ) +{ + DPRINT_ENTER(NETVSC); + + NetVscDestroyReceiveBuffer(NetDevice); + NetVscDestroySendBuffer(NetDevice); + + DPRINT_EXIT(NETVSC); +} + + +/*++ + +Name: + NetVscOnDeviceAdd() + +Description: + Callback when the device belonging to this driver is added + +--*/ +int +NetVscOnDeviceAdd( + DEVICE_OBJECT *Device, + void *AdditionalInfo + ) +{ + int ret=0; + int i; + + NETVSC_DEVICE* netDevice; + NETVSC_PACKET* packet; + LIST_ENTRY *entry; + + NETVSC_DRIVER_OBJECT *netDriver = (NETVSC_DRIVER_OBJECT*) Device->Driver;; + + DPRINT_ENTER(NETVSC); + + netDevice = AllocNetDevice(Device); + if (!netDevice) + { + ret = -1; + goto Cleanup; + } + + DPRINT_DBG(NETVSC, "netvsc channel object allocated - %p", netDevice); + + // Initialize the NetVSC channel extension + netDevice->ReceiveBufferSize = NETVSC_RECEIVE_BUFFER_SIZE; + netDevice->ReceivePacketListLock = SpinlockCreate(); + + netDevice->SendBufferSize = NETVSC_SEND_BUFFER_SIZE; + + INITIALIZE_LIST_HEAD(&netDevice->ReceivePacketList); + + for (i=0; i < NETVSC_RECEIVE_PACKETLIST_COUNT; i++) + { + packet = MemAllocZeroed(sizeof(NETVSC_PACKET) + (NETVSC_RECEIVE_SG_COUNT* sizeof(PAGE_BUFFER))); + if (!packet) + { + DPRINT_DBG(NETVSC, "unable to allocate netvsc pkts for receive pool (wanted %d got %d)", NETVSC_RECEIVE_PACKETLIST_COUNT, i); + break; + } + + INSERT_TAIL_LIST(&netDevice->ReceivePacketList, &packet->ListEntry); + } + netDevice->ChannelInitEvent = WaitEventCreate(); + + // Open the channel + ret = Device->Driver->VmbusChannelInterface.Open(Device, + netDriver->RingBufferSize, + netDriver->RingBufferSize, + NULL, 0, + NetVscOnChannelCallback, + Device + ); + + if (ret != 0) + { + DPRINT_ERR(NETVSC, "unable to open channel: %d", ret); + ret = -1; + goto Cleanup; + } + + // Channel is opened + DPRINT_INFO(NETVSC, "*** NetVSC channel opened successfully! ***"); + + // Connect with the NetVsp + ret = NetVscConnectToVsp(Device); + if (ret != 0) + { + DPRINT_ERR(NETVSC, "unable to connect to NetVSP - %d", ret); + ret = -1; + goto Close; + } + + DPRINT_INFO(NETVSC, "*** NetVSC channel handshake result - %d ***", ret); + + DPRINT_EXIT(NETVSC); + return ret; + +Close: + // Now, we can close the channel safely + Device->Driver->VmbusChannelInterface.Close(Device); + +Cleanup: + + if (netDevice) + { + WaitEventClose(netDevice->ChannelInitEvent); + + while (!IsListEmpty(&netDevice->ReceivePacketList)) + { + entry = REMOVE_HEAD_LIST(&netDevice->ReceivePacketList); + packet = CONTAINING_RECORD(entry, NETVSC_PACKET, ListEntry); + MemFree(packet); + } + + SpinlockClose(netDevice->ReceivePacketListLock); + + ReleaseOutboundNetDevice(Device); + ReleaseInboundNetDevice(Device); + + FreeNetDevice(netDevice); + } + + DPRINT_EXIT(NETVSC); + return ret; +} + + +/*++ + +Name: + NetVscOnDeviceRemove() + +Description: + Callback when the root bus device is removed + +--*/ +int +NetVscOnDeviceRemove( + DEVICE_OBJECT *Device + ) +{ + NETVSC_DEVICE *netDevice; + NETVSC_PACKET *netvscPacket; + int ret=0; + LIST_ENTRY *entry; + + DPRINT_ENTER(NETVSC); + + DPRINT_INFO(NETVSC, "Disabling outbound traffic on net device (%p)...", Device->Extension); + + // Stop outbound traffic ie sends and receives completions + netDevice = ReleaseOutboundNetDevice(Device); + if (!netDevice) + { + DPRINT_ERR(NETVSC, "No net device present!!"); + return -1; + } + + // Wait for all send completions + while (netDevice->NumOutstandingSends) + { + DPRINT_INFO(NETVSC, "waiting for %d requests to complete...", netDevice->NumOutstandingSends); + + Sleep(100); + } + + DPRINT_INFO(NETVSC, "Disconnecting from netvsp..."); + + NetVscDisconnectFromVsp(netDevice); + + DPRINT_INFO(NETVSC, "Disabling inbound traffic on net device (%p)...", Device->Extension); + + // Stop inbound traffic ie receives and sends completions + netDevice = ReleaseInboundNetDevice(Device); + + // At this point, no one should be accessing netDevice except in here + DPRINT_INFO(NETVSC, "net device (%p) safe to remove", netDevice); + + // Now, we can close the channel safely + Device->Driver->VmbusChannelInterface.Close(Device); + + // Release all resources + while (!IsListEmpty(&netDevice->ReceivePacketList)) + { + entry = REMOVE_HEAD_LIST(&netDevice->ReceivePacketList); + netvscPacket = CONTAINING_RECORD(entry, NETVSC_PACKET, ListEntry); + + MemFree(netvscPacket); + } + + SpinlockClose(netDevice->ReceivePacketListLock); + WaitEventClose(netDevice->ChannelInitEvent); + FreeNetDevice(netDevice); + + DPRINT_EXIT(NETVSC); + return ret; +} + + + +/*++ + +Name: + NetVscOnCleanup() + +Description: + Perform any cleanup when the driver is removed + +--*/ +void +NetVscOnCleanup( + DRIVER_OBJECT *drv + ) +{ + DPRINT_ENTER(NETVSC); + + DPRINT_EXIT(NETVSC); +} + +static void +NetVscOnSendCompletion( + DEVICE_OBJECT *Device, + VMPACKET_DESCRIPTOR *Packet + ) +{ + NETVSC_DEVICE* netDevice; + NVSP_MESSAGE *nvspPacket; + NETVSC_PACKET *nvscPacket; + + DPRINT_ENTER(NETVSC); + + netDevice = GetInboundNetDevice(Device); + if (!netDevice) + { + DPRINT_ERR(NETVSC, "unable to get net device...device being destroyed?"); + DPRINT_EXIT(NETVSC); + return; + } + + nvspPacket = (NVSP_MESSAGE*)((ULONG_PTR)Packet + (Packet->DataOffset8 << 3)); + + DPRINT_DBG(NETVSC, "send completion packet - type %d", nvspPacket->Header.MessageType); + + if (nvspPacket->Header.MessageType == NvspMessageTypeInitComplete || + nvspPacket->Header.MessageType == NvspMessage1TypeSendReceiveBufferComplete || + nvspPacket->Header.MessageType == NvspMessage1TypeSendSendBufferComplete) + { + // Copy the response back + memcpy(&netDevice->ChannelInitPacket, nvspPacket, sizeof(NVSP_MESSAGE)); + WaitEventSet(netDevice->ChannelInitEvent); + } + else if (nvspPacket->Header.MessageType == NvspMessage1TypeSendRNDISPacketComplete) + { + // Get the send context + nvscPacket = (NETVSC_PACKET *)(ULONG_PTR)Packet->TransactionId; + ASSERT(nvscPacket); + + // Notify the layer above us + nvscPacket->Completion.Send.OnSendCompletion(nvscPacket->Completion.Send.SendCompletionContext); + + InterlockedDecrement(&netDevice->NumOutstandingSends); + } + else + { + DPRINT_ERR(NETVSC, "Unknown send completion packet type - %d received!!", nvspPacket->Header.MessageType); + } + + PutNetDevice(Device); + DPRINT_EXIT(NETVSC); +} + + + +static int +NetVscOnSend( + DEVICE_OBJECT *Device, + NETVSC_PACKET *Packet + ) +{ + NETVSC_DEVICE* netDevice; + int ret=0; + + NVSP_MESSAGE sendMessage; + + DPRINT_ENTER(NETVSC); + + netDevice = GetOutboundNetDevice(Device); + if (!netDevice) + { + DPRINT_ERR(NETVSC, "net device (%p) shutting down...ignoring outbound packets", netDevice); + DPRINT_EXIT(NETVSC); + return -2; + } + + sendMessage.Header.MessageType = NvspMessage1TypeSendRNDISPacket; + if (Packet->IsDataPacket) + sendMessage.Messages.Version1Messages.SendRNDISPacket.ChannelType = 0;// 0 is RMC_DATA; + else + sendMessage.Messages.Version1Messages.SendRNDISPacket.ChannelType = 1;// 1 is RMC_CONTROL; + + // Not using send buffer section + sendMessage.Messages.Version1Messages.SendRNDISPacket.SendBufferSectionIndex = 0xFFFFFFFF; + sendMessage.Messages.Version1Messages.SendRNDISPacket.SendBufferSectionSize = 0; + + if (Packet->PageBufferCount) + { + ret = Device->Driver->VmbusChannelInterface.SendPacketPageBuffer(Device, + Packet->PageBuffers, + Packet->PageBufferCount, + &sendMessage, + sizeof(NVSP_MESSAGE), + (ULONG_PTR)Packet); + } + else + { + ret = Device->Driver->VmbusChannelInterface.SendPacket(Device, + &sendMessage, + sizeof(NVSP_MESSAGE), + (ULONG_PTR)Packet, + VmbusPacketTypeDataInBand, + VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED); + + } + + if (ret != 0) + { + DPRINT_ERR(NETVSC, "Unable to send packet %p ret %d", Packet, ret); + } + + InterlockedIncrement(&netDevice->NumOutstandingSends); + PutNetDevice(Device); + + DPRINT_EXIT(NETVSC); + return ret; +} + + +static void +NetVscOnReceive( + DEVICE_OBJECT *Device, + VMPACKET_DESCRIPTOR *Packet + ) +{ + NETVSC_DEVICE* netDevice; + VMTRANSFER_PAGE_PACKET_HEADER *vmxferpagePacket; + NVSP_MESSAGE *nvspPacket; + NETVSC_PACKET *netvscPacket=NULL; + LIST_ENTRY* entry; + ULONG_PTR start; + ULONG_PTR end, endVirtual; + //NETVSC_DRIVER_OBJECT *netvscDriver; + XFERPAGE_PACKET *xferpagePacket=NULL; + LIST_ENTRY listHead; + + int i=0, j=0; + int count=0, bytesRemain=0; + + DPRINT_ENTER(NETVSC); + + netDevice = GetInboundNetDevice(Device); + if (!netDevice) + { + DPRINT_ERR(NETVSC, "unable to get net device...device being destroyed?"); + DPRINT_EXIT(NETVSC); + return; + } + + // All inbound packets other than send completion should be xfer page packet + if (Packet->Type != VmbusPacketTypeDataUsingTransferPages) + { + DPRINT_ERR(NETVSC, "Unknown packet type received - %d", Packet->Type); + PutNetDevice(Device); + return; + } + + nvspPacket = (NVSP_MESSAGE*)((ULONG_PTR)Packet + (Packet->DataOffset8 << 3)); + + // Make sure this is a valid nvsp packet + if (nvspPacket->Header.MessageType != NvspMessage1TypeSendRNDISPacket ) + { + DPRINT_ERR(NETVSC, "Unknown nvsp packet type received - %d", nvspPacket->Header.MessageType); + PutNetDevice(Device); + return; + } + + DPRINT_DBG(NETVSC, "NVSP packet received - type %d", nvspPacket->Header.MessageType); + + vmxferpagePacket = (VMTRANSFER_PAGE_PACKET_HEADER*)Packet; + + if (vmxferpagePacket->TransferPageSetId != NETVSC_RECEIVE_BUFFER_ID) + { + DPRINT_ERR(NETVSC, "Invalid xfer page set id - expecting %x got %x", NETVSC_RECEIVE_BUFFER_ID, vmxferpagePacket->TransferPageSetId); + PutNetDevice(Device); + return; + } + + DPRINT_DBG(NETVSC, "xfer page - range count %d", vmxferpagePacket->RangeCount); + + INITIALIZE_LIST_HEAD(&listHead); + + // Grab free packets (range count + 1) to represent this xfer page packet. +1 to represent + // the xfer page packet itself. We grab it here so that we know exactly how many we can fulfil + SpinlockAcquire(netDevice->ReceivePacketListLock); + while (!IsListEmpty(&netDevice->ReceivePacketList)) + { + entry = REMOVE_HEAD_LIST(&netDevice->ReceivePacketList); + netvscPacket = CONTAINING_RECORD(entry, NETVSC_PACKET, ListEntry); + + INSERT_TAIL_LIST(&listHead, &netvscPacket->ListEntry); + + if (++count == vmxferpagePacket->RangeCount + 1) + break; + } + SpinlockRelease(netDevice->ReceivePacketListLock); + + // We need at least 2 netvsc pkts (1 to represent the xfer page and at least 1 for the range) + // i.e. we can handled some of the xfer page packet ranges... + if (count < 2) + { + DPRINT_ERR(NETVSC, "Got only %d netvsc pkt...needed %d pkts. Dropping this xfer page packet completely!", count, vmxferpagePacket->RangeCount+1); + + // Return it to the freelist + SpinlockAcquire(netDevice->ReceivePacketListLock); + for (i=count; i != 0; i--) + { + entry = REMOVE_HEAD_LIST(&listHead); + netvscPacket = CONTAINING_RECORD(entry, NETVSC_PACKET, ListEntry); + + INSERT_TAIL_LIST(&netDevice->ReceivePacketList, &netvscPacket->ListEntry); + } + SpinlockRelease(netDevice->ReceivePacketListLock); + + NetVscSendReceiveCompletion(Device, vmxferpagePacket->d.TransactionId); + + PutNetDevice(Device); + return; + } + + // Remove the 1st packet to represent the xfer page packet itself + entry = REMOVE_HEAD_LIST(&listHead); + xferpagePacket = CONTAINING_RECORD(entry, XFERPAGE_PACKET, ListEntry); + xferpagePacket->Count = count - 1; // This is how much we can satisfy + ASSERT(xferpagePacket->Count > 0 && xferpagePacket->Count <= vmxferpagePacket->RangeCount); + + if (xferpagePacket->Count != vmxferpagePacket->RangeCount) + { + DPRINT_INFO(NETVSC, "Needed %d netvsc pkts to satisy this xfer page...got %d", vmxferpagePacket->RangeCount, xferpagePacket->Count); + } + + // Each range represents 1 RNDIS pkt that contains 1 ethernet frame + for (i=0; i < (count - 1); i++) + { + entry = REMOVE_HEAD_LIST(&listHead); + netvscPacket = CONTAINING_RECORD(entry, NETVSC_PACKET, ListEntry); + + // Initialize the netvsc packet + netvscPacket->XferPagePacket = xferpagePacket; + netvscPacket->Completion.Recv.OnReceiveCompletion = NetVscOnReceiveCompletion; + netvscPacket->Completion.Recv.ReceiveCompletionContext = netvscPacket; + netvscPacket->Device = Device; + netvscPacket->Completion.Recv.ReceiveCompletionTid = vmxferpagePacket->d.TransactionId; // Save this so that we can send it back + + netvscPacket->TotalDataBufferLength = vmxferpagePacket->Ranges[i].ByteCount; + netvscPacket->PageBufferCount = 1; + + ASSERT(vmxferpagePacket->Ranges[i].ByteOffset + vmxferpagePacket->Ranges[i].ByteCount < netDevice->ReceiveBufferSize); + + netvscPacket->PageBuffers[0].Length = vmxferpagePacket->Ranges[i].ByteCount; + + start = GetPhysicalAddress((void*)((ULONG_PTR)netDevice->ReceiveBuffer + vmxferpagePacket->Ranges[i].ByteOffset)); + + netvscPacket->PageBuffers[0].Pfn = start >> PAGE_SHIFT; + endVirtual = (ULONG_PTR)netDevice->ReceiveBuffer + + vmxferpagePacket->Ranges[i].ByteOffset + + vmxferpagePacket->Ranges[i].ByteCount -1; + end = GetPhysicalAddress((void*)endVirtual); + + // Calculate the page relative offset + netvscPacket->PageBuffers[0].Offset = vmxferpagePacket->Ranges[i].ByteOffset & (PAGE_SIZE -1); + if ((end >> PAGE_SHIFT) != (start>>PAGE_SHIFT)) { + //Handle frame across multiple pages: + netvscPacket->PageBuffers[0].Length = + (netvscPacket->PageBuffers[0].Pfn <TotalDataBufferLength - netvscPacket->PageBuffers[0].Length; + for (j=1; jPageBuffers[j].Offset = 0; + if (bytesRemain <= PAGE_SIZE) { + netvscPacket->PageBuffers[j].Length = bytesRemain; + bytesRemain = 0; + } else { + netvscPacket->PageBuffers[j].Length = PAGE_SIZE; + bytesRemain -= PAGE_SIZE; + } + netvscPacket->PageBuffers[j].Pfn = + GetPhysicalAddress((void*)(endVirtual - bytesRemain)) >> PAGE_SHIFT; + netvscPacket->PageBufferCount++; + if (bytesRemain == 0) + break; + } + ASSERT(bytesRemain == 0); + } + DPRINT_DBG(NETVSC, "[%d] - (abs offset %u len %u) => (pfn %llx, offset %u, len %u)", + i, + vmxferpagePacket->Ranges[i].ByteOffset, + vmxferpagePacket->Ranges[i].ByteCount, + netvscPacket->PageBuffers[0].Pfn, + netvscPacket->PageBuffers[0].Offset, + netvscPacket->PageBuffers[0].Length); + + // Pass it to the upper layer + ((NETVSC_DRIVER_OBJECT*)Device->Driver)->OnReceiveCallback(Device, netvscPacket); + + NetVscOnReceiveCompletion(netvscPacket->Completion.Recv.ReceiveCompletionContext); + } + + ASSERT(IsListEmpty(&listHead)); + + PutNetDevice(Device); + DPRINT_EXIT(NETVSC); +} + + +static void +NetVscSendReceiveCompletion( + DEVICE_OBJECT *Device, + UINT64 TransactionId + ) +{ + NVSP_MESSAGE recvcompMessage; + int retries=0; + int ret=0; + + DPRINT_DBG(NETVSC, "Sending receive completion pkt - %llx", TransactionId); + + recvcompMessage.Header.MessageType = NvspMessage1TypeSendRNDISPacketComplete; + + // FIXME: Pass in the status + recvcompMessage.Messages.Version1Messages.SendRNDISPacketComplete.Status = NvspStatusSuccess; + +retry_send_cmplt: + // Send the completion + ret = Device->Driver->VmbusChannelInterface.SendPacket(Device, + &recvcompMessage, + sizeof(NVSP_MESSAGE), + TransactionId, + VmbusPacketTypeCompletion, + 0); + if (ret == 0) // success + { + // no-op + } + else if (ret == -1) // no more room...wait a bit and attempt to retry 3 times + { + retries++; + DPRINT_ERR(NETVSC, "unable to send receive completion pkt (tid %llx)...retrying %d", TransactionId, retries); + + if (retries < 4) + { + Sleep(100); + goto retry_send_cmplt; + } + else + { + DPRINT_ERR(NETVSC, "unable to send receive completion pkt (tid %llx)...give up retrying", TransactionId); + } + } + else + { + DPRINT_ERR(NETVSC, "unable to send receive completion pkt - %llx", TransactionId); + } +} + +// +// Send a receive completion packet to RNDIS device (ie NetVsp) +// +static void +NetVscOnReceiveCompletion( + PVOID Context) +{ + NETVSC_PACKET *packet = (NETVSC_PACKET*)Context; + DEVICE_OBJECT *device = (DEVICE_OBJECT*)packet->Device; + NETVSC_DEVICE* netDevice; + UINT64 transactionId=0; + BOOL fSendReceiveComp = FALSE; + + DPRINT_ENTER(NETVSC); + + ASSERT(packet->XferPagePacket); + + // Even though it seems logical to do a GetOutboundNetDevice() here to send out receive completion, + // we are using GetInboundNetDevice() since we may have disable outbound traffic already. + netDevice = GetInboundNetDevice(device); + if (!netDevice) + { + DPRINT_ERR(NETVSC, "unable to get net device...device being destroyed?"); + DPRINT_EXIT(NETVSC); + return; + } + + // Overloading use of the lock. + SpinlockAcquire(netDevice->ReceivePacketListLock); + + ASSERT(packet->XferPagePacket->Count > 0); + packet->XferPagePacket->Count--; + + // Last one in the line that represent 1 xfer page packet. + // Return the xfer page packet itself to the freelist + if (packet->XferPagePacket->Count == 0) + { + fSendReceiveComp = TRUE; + transactionId = packet->Completion.Recv.ReceiveCompletionTid; + + INSERT_TAIL_LIST(&netDevice->ReceivePacketList, &packet->XferPagePacket->ListEntry); + } + + // Put the packet back + INSERT_TAIL_LIST(&netDevice->ReceivePacketList, &packet->ListEntry); + SpinlockRelease(netDevice->ReceivePacketListLock); + + // Send a receive completion for the xfer page packet + if (fSendReceiveComp) + { + NetVscSendReceiveCompletion(device, transactionId); + } + + PutNetDevice(device); + DPRINT_EXIT(NETVSC); +} + + + +void +NetVscOnChannelCallback( + PVOID Context + ) +{ + const int netPacketSize=2048; + int ret=0; + DEVICE_OBJECT *device=(DEVICE_OBJECT*)Context; + NETVSC_DEVICE *netDevice; + + UINT32 bytesRecvd; + UINT64 requestId; + UCHAR packet[netPacketSize]; + VMPACKET_DESCRIPTOR *desc; + UCHAR *buffer=packet; + int bufferlen=netPacketSize; + + + DPRINT_ENTER(NETVSC); + + ASSERT(device); + + netDevice = GetInboundNetDevice(device); + if (!netDevice) + { + DPRINT_ERR(NETVSC, "net device (%p) shutting down...ignoring inbound packets", netDevice); + DPRINT_EXIT(NETVSC); + return; + } + + do + { + ret = device->Driver->VmbusChannelInterface.RecvPacketRaw(device, + buffer, + bufferlen, + &bytesRecvd, + &requestId); + + if (ret == 0) + { + if (bytesRecvd > 0) + { + DPRINT_DBG(NETVSC, "receive %d bytes, tid %llx", bytesRecvd, requestId); + + desc = (VMPACKET_DESCRIPTOR*)buffer; + switch (desc->Type) + { + case VmbusPacketTypeCompletion: + NetVscOnSendCompletion(device, desc); + break; + + case VmbusPacketTypeDataUsingTransferPages: + NetVscOnReceive(device, desc); + break; + + default: + DPRINT_ERR(NETVSC, "unhandled packet type %d, tid %llx len %d\n", desc->Type, requestId, bytesRecvd); + break; + } + + // reset + if (bufferlen > netPacketSize) + { + MemFree(buffer); + + buffer = packet; + bufferlen = netPacketSize; + } + } + else + { + //DPRINT_DBG(NETVSC, "nothing else to read..."); + + // reset + if (bufferlen > netPacketSize) + { + MemFree(buffer); + + buffer = packet; + bufferlen = netPacketSize; + } + + break; + } + } + else if (ret == -2) // Handle large packet + { + buffer = MemAllocAtomic(bytesRecvd); + if (buffer == NULL) + { + // Try again next time around + DPRINT_ERR(NETVSC, "unable to allocate buffer of size (%d)!!", bytesRecvd); + break; + } + + bufferlen = bytesRecvd; + } + else + { + ASSERT(0); + } + } while (1); + + PutNetDevice(device); + DPRINT_EXIT(NETVSC); + return; +} --- /dev/null +++ b/drivers/staging/hv/netvsc_drv.c @@ -0,0 +1,720 @@ +/* + * + * Copyright (c) 2009, Microsoft Corporation. + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along with + * this program; if not, write to the Free Software Foundation, Inc., 59 Temple + * Place - Suite 330, Boston, MA 02111-1307 USA. + * + * Authors: + * Hank Janssen + * + */ + + +#include +#include +#include +#include +#if defined(KERNEL_2_6_5) || defined(KERNEL_2_6_9) +#include +#else +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "logging.h" +#include "vmbus.h" + +#include "NetVscApi.h" + +MODULE_LICENSE("GPL"); + +// +// Static decl +// +static int netvsc_probe(struct device *device); +static int netvsc_remove(struct device *device); +static int netvsc_open(struct net_device *net); +static void netvsc_xmit_completion(void *context); +static int netvsc_start_xmit (struct sk_buff *skb, struct net_device *net); +static int netvsc_recv_callback(DEVICE_OBJECT *device_obj, NETVSC_PACKET* Packet); +static int netvsc_close(struct net_device *net); +static struct net_device_stats *netvsc_get_stats(struct net_device *net); +static void netvsc_linkstatus_callback(DEVICE_OBJECT *device_obj, unsigned int status); + +// +// Data types +// +struct net_device_context { + struct device_context *device_ctx; // point back to our device context + struct net_device_stats stats; +}; + +struct netvsc_driver_context { + // !! These must be the first 2 fields !! + struct driver_context drv_ctx; + NETVSC_DRIVER_OBJECT drv_obj; +}; + +// +// Globals +// + +static int netvsc_ringbuffer_size = NETVSC_DEVICE_RING_BUFFER_SIZE; + +// The one and only one +static struct netvsc_driver_context g_netvsc_drv; + +// +// Routines +// + +/*++ + +Name: netvsc_drv_init() + +Desc: NetVsc driver initialization + +--*/ +int netvsc_drv_init(PFN_DRIVERINITIALIZE pfn_drv_init) +{ + int ret=0; + NETVSC_DRIVER_OBJECT *net_drv_obj=&g_netvsc_drv.drv_obj; + struct driver_context *drv_ctx=&g_netvsc_drv.drv_ctx; + + DPRINT_ENTER(NETVSC_DRV); + + vmbus_get_interface(&net_drv_obj->Base.VmbusChannelInterface); + + net_drv_obj->RingBufferSize = netvsc_ringbuffer_size; + net_drv_obj->OnReceiveCallback = netvsc_recv_callback; + net_drv_obj->OnLinkStatusChanged = netvsc_linkstatus_callback; + + // Callback to client driver to complete the initialization + pfn_drv_init(&net_drv_obj->Base); + + drv_ctx->driver.name = net_drv_obj->Base.name; + memcpy(&drv_ctx->class_id, &net_drv_obj->Base.deviceType, sizeof(GUID)); + +#if defined(KERNEL_2_6_5) || defined(KERNEL_2_6_9) + drv_ctx->driver.probe = netvsc_probe; + drv_ctx->driver.remove = netvsc_remove; +#else + drv_ctx->probe = netvsc_probe; + drv_ctx->remove = netvsc_remove; +#endif + + // The driver belongs to vmbus + vmbus_child_driver_register(drv_ctx); + + DPRINT_EXIT(NETVSC_DRV); + + return ret; +} + +/*++ + +Name: netvsc_get_stats() + +Desc: Get the network stats + +--*/ +static struct net_device_stats *netvsc_get_stats(struct net_device *net) +{ + struct net_device_context *net_device_ctx = netdev_priv(net); + + return &net_device_ctx->stats; +} + +/*++ + +Name: netvsc_set_multicast_list() + +Desc: Set the multicast list + +Remark: No-op here +--*/ +static void netvsc_set_multicast_list(UNUSED_VAR(struct net_device *net)) +{ +} + + +/*++ + +Name: netvsc_probe() + +Desc: Add the specified new device to this driver + +--*/ +static int netvsc_probe(struct device *device) +{ + int ret=0; + + struct driver_context *driver_ctx = driver_to_driver_context(device->driver); + struct netvsc_driver_context *net_drv_ctx = (struct netvsc_driver_context*)driver_ctx; + NETVSC_DRIVER_OBJECT *net_drv_obj = &net_drv_ctx->drv_obj; + + struct device_context *device_ctx = device_to_device_context(device); + DEVICE_OBJECT *device_obj = &device_ctx->device_obj; + + struct net_device *net = NULL; + struct net_device_context *net_device_ctx; + NETVSC_DEVICE_INFO device_info; + + DPRINT_ENTER(NETVSC_DRV); + + if (!net_drv_obj->Base.OnDeviceAdd) + { + return -1; + } + + net = alloc_netdev(sizeof(struct net_device_context), "seth%d", ether_setup); + //net = alloc_etherdev(sizeof(struct net_device_context)); + if (!net) + { + return -1; + } + + // Set initial state + netif_carrier_off(net); + netif_stop_queue(net); + + net_device_ctx = netdev_priv(net); + net_device_ctx->device_ctx = device_ctx; + device->driver_data = net; + + // Notify the netvsc driver of the new device + ret = net_drv_obj->Base.OnDeviceAdd(device_obj, (void*)&device_info); + if (ret != 0) + { + free_netdev(net); + device->driver_data = NULL; + + DPRINT_ERR(NETVSC_DRV, "unable to add netvsc device (ret %d)", ret); + return ret; + } + + // If carrier is still off ie we did not get a link status callback, update it if necessary + // FIXME: We should use a atomic or test/set instead to avoid getting out of sync with the device's link status + if (!netif_carrier_ok(net)) + { + if (!device_info.LinkState) + { + netif_carrier_on(net); + } + } + + memcpy(net->dev_addr, device_info.MacAddr, ETH_ALEN); + + net->open = netvsc_open; + net->hard_start_xmit = netvsc_start_xmit; + net->stop = netvsc_close; + net->get_stats = netvsc_get_stats; + net->set_multicast_list = netvsc_set_multicast_list; + +#if !defined(KERNEL_2_6_27) + SET_MODULE_OWNER(net); +#endif + SET_NETDEV_DEV(net, device); + + ret = register_netdev(net); + if (ret != 0) + { + // Remove the device and release the resource + net_drv_obj->Base.OnDeviceRemove(device_obj); + free_netdev(net); + } + + DPRINT_EXIT(NETVSC_DRV); + + return ret; +} + +static int netvsc_remove(struct device *device) +{ + int ret=0; + struct driver_context *driver_ctx = driver_to_driver_context(device->driver); + struct netvsc_driver_context *net_drv_ctx = (struct netvsc_driver_context*)driver_ctx; + NETVSC_DRIVER_OBJECT *net_drv_obj = &net_drv_ctx->drv_obj; + + struct device_context *device_ctx = device_to_device_context(device); + struct net_device *net = (struct net_device *)device_ctx->device.driver_data; + DEVICE_OBJECT *device_obj = &device_ctx->device_obj; + + DPRINT_ENTER(NETVSC_DRV); + + if (net == NULL) + { + DPRINT_INFO(NETVSC, "no net device to remove"); + DPRINT_EXIT(NETVSC_DRV); + return 0; + } + + if (!net_drv_obj->Base.OnDeviceRemove) + { + DPRINT_EXIT(NETVSC_DRV); + return -1; + } + + // Stop outbound asap + netif_stop_queue(net); + //netif_carrier_off(net); + + unregister_netdev(net); + + // Call to the vsc driver to let it know that the device is being removed + ret = net_drv_obj->Base.OnDeviceRemove(device_obj); + if (ret != 0) + { + // TODO: + DPRINT_ERR(NETVSC, "unable to remove vsc device (ret %d)", ret); + } + + free_netdev(net); + + DPRINT_EXIT(NETVSC_DRV); + + return ret; +} + +/*++ + +Name: netvsc_open() + +Desc: Open the specified interface device + +--*/ +static int netvsc_open(struct net_device *net) +{ + int ret=0; + struct net_device_context *net_device_ctx = netdev_priv(net); + struct driver_context *driver_ctx = driver_to_driver_context(net_device_ctx->device_ctx->device.driver); + struct netvsc_driver_context *net_drv_ctx = (struct netvsc_driver_context*)driver_ctx; + NETVSC_DRIVER_OBJECT *net_drv_obj = &net_drv_ctx->drv_obj; + + DEVICE_OBJECT *device_obj = &net_device_ctx->device_ctx->device_obj; + + DPRINT_ENTER(NETVSC_DRV); + + if (netif_carrier_ok(net)) + { + memset(&net_device_ctx->stats, 0 , sizeof(struct net_device_stats)); + + // Open up the device + ret = net_drv_obj->OnOpen(device_obj); + if (ret != 0) + { + DPRINT_ERR(NETVSC_DRV, "unable to open device (ret %d).", ret); + return ret; + } + + netif_start_queue(net); + } + else + { + DPRINT_ERR(NETVSC_DRV, "unable to open device...link is down."); + } + + DPRINT_EXIT(NETVSC_DRV); + return ret; +} + +/*++ + +Name: netvsc_close() + +Desc: Close the specified interface device + +--*/ +static int netvsc_close(struct net_device *net) +{ + int ret=0; + struct net_device_context *net_device_ctx = netdev_priv(net); + struct driver_context *driver_ctx = driver_to_driver_context(net_device_ctx->device_ctx->device.driver); + struct netvsc_driver_context *net_drv_ctx = (struct netvsc_driver_context*)driver_ctx; + NETVSC_DRIVER_OBJECT *net_drv_obj = &net_drv_ctx->drv_obj; + + DEVICE_OBJECT *device_obj = &net_device_ctx->device_ctx->device_obj; + + DPRINT_ENTER(NETVSC_DRV); + + netif_stop_queue(net); + + ret = net_drv_obj->OnClose(device_obj); + if (ret != 0) + { + DPRINT_ERR(NETVSC_DRV, "unable to close device (ret %d).", ret); + } + + DPRINT_EXIT(NETVSC_DRV); + + return ret; +} + + +/*++ + +Name: netvsc_xmit_completion() + +Desc: Send completion processing + +--*/ +static void netvsc_xmit_completion(void *context) +{ + NETVSC_PACKET *packet = (NETVSC_PACKET *)context; + struct sk_buff *skb = (struct sk_buff *)(ULONG_PTR)packet->Completion.Send.SendCompletionTid; + struct net_device* net; + + DPRINT_ENTER(NETVSC_DRV); + + kfree(packet); + + if (skb) + { + net = skb->dev; + + dev_kfree_skb_any(skb); + + if (netif_queue_stopped(net)) + { + DPRINT_INFO(NETVSC_DRV, "net device (%p) waking up...", net); + + netif_wake_queue(net); + } + } + + DPRINT_EXIT(NETVSC_DRV); +} + +/*++ + +Name: netvsc_start_xmit() + +Desc: Start a send + +--*/ +static int netvsc_start_xmit (struct sk_buff *skb, struct net_device *net) +{ + int ret=0; + struct net_device_context *net_device_ctx = netdev_priv(net); + struct driver_context *driver_ctx = driver_to_driver_context(net_device_ctx->device_ctx->device.driver); + struct netvsc_driver_context *net_drv_ctx = (struct netvsc_driver_context*)driver_ctx; + NETVSC_DRIVER_OBJECT *net_drv_obj = &net_drv_ctx->drv_obj; + + int i=0; + NETVSC_PACKET* packet; + int num_frags; + int retries=0; + + DPRINT_ENTER(NETVSC_DRV); + + // Support only 1 chain of frags + ASSERT(skb_shinfo(skb)->frag_list == NULL); + ASSERT(skb->dev == net); + + DPRINT_DBG(NETVSC_DRV, "xmit packet - len %d data_len %d", skb->len, skb->data_len); + + // Add 1 for skb->data and any additional ones requested + num_frags = skb_shinfo(skb)->nr_frags + 1 + net_drv_obj->AdditionalRequestPageBufferCount; + + // Allocate a netvsc packet based on # of frags. + packet = kzalloc(sizeof(NETVSC_PACKET) + (num_frags * sizeof(PAGE_BUFFER)) + net_drv_obj->RequestExtSize, GFP_ATOMIC); + if (!packet) + { + DPRINT_ERR(NETVSC_DRV, "unable to allocate NETVSC_PACKET"); + return -1; + } + + packet->Extension = (void*)(unsigned long)packet + sizeof(NETVSC_PACKET) + (num_frags * sizeof(PAGE_BUFFER)) ; + + // Setup the rndis header + packet->PageBufferCount = num_frags; + + // TODO: Flush all write buffers/ memory fence ??? + //wmb(); + + // Initialize it from the skb + ASSERT(skb->data); + packet->TotalDataBufferLength = skb->len; + + // Start filling in the page buffers starting at AdditionalRequestPageBufferCount offset + packet->PageBuffers[net_drv_obj->AdditionalRequestPageBufferCount].Pfn = virt_to_phys(skb->data) >> PAGE_SHIFT; + packet->PageBuffers[net_drv_obj->AdditionalRequestPageBufferCount].Offset = (unsigned long)skb->data & (PAGE_SIZE -1); + packet->PageBuffers[net_drv_obj->AdditionalRequestPageBufferCount].Length = skb->len - skb->data_len; + + ASSERT((skb->len - skb->data_len) <= PAGE_SIZE); + + for (i=net_drv_obj->AdditionalRequestPageBufferCount+1; iPageBuffers[i].Pfn = page_to_pfn(skb_shinfo(skb)->frags[i-(net_drv_obj->AdditionalRequestPageBufferCount+1)].page); + packet->PageBuffers[i].Offset = skb_shinfo(skb)->frags[i-(net_drv_obj->AdditionalRequestPageBufferCount+1)].page_offset; + packet->PageBuffers[i].Length = skb_shinfo(skb)->frags[i-(net_drv_obj->AdditionalRequestPageBufferCount+1)].size; + } + + // Set the completion routine + packet->Completion.Send.OnSendCompletion = netvsc_xmit_completion; + packet->Completion.Send.SendCompletionContext = packet; + packet->Completion.Send.SendCompletionTid = (ULONG_PTR)skb; + +retry_send: + ret = net_drv_obj->OnSend(&net_device_ctx->device_ctx->device_obj, packet); + + if (ret == 0) + { +#ifdef KERNEL_2_6_5 +#define NETDEV_TX_OK 0 +#define NETDEV_TX_BUSY 0 +#endif + ret = NETDEV_TX_OK; + net_device_ctx->stats.tx_bytes += skb->len; + net_device_ctx->stats.tx_packets++; + } + else + { + retries++; + if (retries < 4) + { + DPRINT_ERR(NETVSC_DRV, "unable to send...retrying %d...", retries); + udelay(100); + goto retry_send; + } + + // no more room or we are shutting down + DPRINT_ERR(NETVSC_DRV, "unable to send (%d)...marking net device (%p) busy", ret, net); + DPRINT_INFO(NETVSC_DRV, "net device (%p) stopping", net); + + ret = NETDEV_TX_BUSY; + net_device_ctx->stats.tx_dropped++; + + netif_stop_queue(net); + + // Null it since the caller will free it instead of the completion routine + packet->Completion.Send.SendCompletionTid = 0; + + // Release the resources since we will not get any send completion + netvsc_xmit_completion((void*)packet); + } + + DPRINT_DBG(NETVSC_DRV, "# of xmits %lu total size %lu", net_device_ctx->stats.tx_packets, net_device_ctx->stats.tx_bytes); + + DPRINT_EXIT(NETVSC_DRV); + return ret; +} + + +/*++ + +Name: netvsc_linkstatus_callback() + +Desc: Link up/down notification + +--*/ +static void netvsc_linkstatus_callback(DEVICE_OBJECT *device_obj, unsigned int status) +{ + struct device_context* device_ctx = to_device_context(device_obj); + struct net_device* net = (struct net_device *)device_ctx->device.driver_data; + + DPRINT_ENTER(NETVSC_DRV); + + if (!net) + { + DPRINT_ERR(NETVSC_DRV, "got link status but net device not initialized yet"); + return; + } + + if (status == 1) + { + netif_carrier_on(net); + netif_wake_queue(net); + } + else + { + netif_carrier_off(net); + netif_stop_queue(net); + } + DPRINT_EXIT(NETVSC_DRV); +} + + +/*++ + +Name: netvsc_recv_callback() + +Desc: Callback when we receive a packet from the "wire" on the specify device + +--*/ +static int netvsc_recv_callback(DEVICE_OBJECT *device_obj, NETVSC_PACKET* packet) +{ + int ret=0; + struct device_context *device_ctx = to_device_context(device_obj); + struct net_device *net = (struct net_device *)device_ctx->device.driver_data; + struct net_device_context *net_device_ctx; + + struct sk_buff *skb; + void *data; + int i=0; + unsigned long flags; + + DPRINT_ENTER(NETVSC_DRV); + + if (!net) + { + DPRINT_ERR(NETVSC_DRV, "got receive callback but net device not initialized yet"); + return 0; + } + + net_device_ctx = netdev_priv(net); + + // Allocate a skb - TODO preallocate this + //skb = alloc_skb(packet->TotalDataBufferLength, GFP_ATOMIC); + skb = dev_alloc_skb(packet->TotalDataBufferLength + 2); // Pad 2-bytes to align IP header to 16 bytes + ASSERT(skb); + skb_reserve(skb, 2); + skb->dev = net; + + // for kmap_atomic + local_irq_save(flags); + + // Copy to skb. This copy is needed here since the memory pointed by NETVSC_PACKET + // cannot be deallocated + for (i=0; iPageBufferCount; i++) + { + data = kmap_atomic(pfn_to_page(packet->PageBuffers[i].Pfn), KM_IRQ1); + data = (void*)(unsigned long)data + packet->PageBuffers[i].Offset; + + memcpy(skb_put(skb, packet->PageBuffers[i].Length), data, packet->PageBuffers[i].Length); + + kunmap_atomic((void*)((unsigned long)data - packet->PageBuffers[i].Offset), KM_IRQ1); + } + + local_irq_restore(flags); + + skb->protocol = eth_type_trans(skb, net); + + skb->ip_summed = CHECKSUM_NONE; + + // Pass the skb back up. Network stack will deallocate the skb when it is done + ret = netif_rx(skb); + + switch (ret) + { + case NET_RX_DROP: + net_device_ctx->stats.rx_dropped++; + break; + default: + net_device_ctx->stats.rx_packets++; + net_device_ctx->stats.rx_bytes += skb->len; + break; + + } + DPRINT_DBG(NETVSC_DRV, "# of recvs %lu total size %lu", net_device_ctx->stats.rx_packets, net_device_ctx->stats.rx_bytes); + + DPRINT_EXIT(NETVSC_DRV); + + return 0; +} + +static int netvsc_drv_exit_cb(struct device *dev, void *data) +{ + struct device **curr = (struct device **)data; + *curr = dev; + return 1; // stop iterating +} + +/*++ + +Name: netvsc_drv_exit() + +Desc: + +--*/ +void netvsc_drv_exit(void) +{ + NETVSC_DRIVER_OBJECT *netvsc_drv_obj=&g_netvsc_drv.drv_obj; + struct driver_context *drv_ctx=&g_netvsc_drv.drv_ctx; + + struct device *current_dev=NULL; +#if defined(KERNEL_2_6_5) || defined(KERNEL_2_6_9) +#define driver_for_each_device(drv, start, data, fn) \ + struct list_head *ptr, *n; \ + list_for_each_safe(ptr, n, &((drv)->devices)) {\ + struct device *curr_dev;\ + curr_dev = list_entry(ptr, struct device, driver_list);\ + fn(curr_dev, data);\ + } +#endif + + DPRINT_ENTER(NETVSC_DRV); + + while (1) + { + current_dev = NULL; + + // Get the device + driver_for_each_device(&drv_ctx->driver, NULL, (void*)¤t_dev, netvsc_drv_exit_cb); + + if (current_dev == NULL) + break; + + // Initiate removal from the top-down + DPRINT_INFO(NETVSC_DRV, "unregistering device (%p)...", current_dev); + + device_unregister(current_dev); + } + + if (netvsc_drv_obj->Base.OnCleanup) + netvsc_drv_obj->Base.OnCleanup(&netvsc_drv_obj->Base); + + vmbus_child_driver_unregister(drv_ctx); + + DPRINT_EXIT(NETVSC_DRV); + + return; +} + +static int __init netvsc_init(void) +{ + int ret; + + DPRINT_ENTER(NETVSC_DRV); + DPRINT_INFO(NETVSC_DRV, "Netvsc initializing...."); + + ret = netvsc_drv_init(NetVscInitialize); + + DPRINT_EXIT(NETVSC_DRV); + + return ret; +} + +static void __exit netvsc_exit(void) +{ + DPRINT_ENTER(NETVSC_DRV); + + netvsc_drv_exit(); + + DPRINT_EXIT(NETVSC_DRV); +} + +module_param(netvsc_ringbuffer_size, int, S_IRUGO); + +module_init(netvsc_init); +module_exit(netvsc_exit); --- /dev/null +++ b/drivers/staging/hv/NetVsc.h @@ -0,0 +1,91 @@ +/* + * + * Copyright (c) 2009, Microsoft Corporation. + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along with + * this program; if not, write to the Free Software Foundation, Inc., 59 Temple + * Place - Suite 330, Boston, MA 02111-1307 USA. + * + * Authors: + * Hank Janssen + * + */ + + +#ifndef _NETVSC_H_ +#define _NETVSC_H_ + +#include "VmbusPacketFormat.h" +#include "nvspprotocol.h" + +#include "List.h" + +#include "NetVscApi.h" +// +// #defines +// +//#define NVSC_MIN_PROTOCOL_VERSION 1 +//#define NVSC_MAX_PROTOCOL_VERSION 1 + +#define NETVSC_SEND_BUFFER_SIZE 64*1024 // 64K +#define NETVSC_SEND_BUFFER_ID 0xface + + +#define NETVSC_RECEIVE_BUFFER_SIZE 1024*1024 // 1MB + +#define NETVSC_RECEIVE_BUFFER_ID 0xcafe + +#define NETVSC_RECEIVE_SG_COUNT 1 + +// Preallocated receive packets +#define NETVSC_RECEIVE_PACKETLIST_COUNT 256 + +// +// Data types +// + +// Per netvsc channel-specific +typedef struct _NETVSC_DEVICE { + DEVICE_OBJECT *Device; + + int RefCount; + + int NumOutstandingSends; + // List of free preallocated NETVSC_PACKET to represent receive packet + LIST_ENTRY ReceivePacketList; + HANDLE ReceivePacketListLock; + + // Send buffer allocated by us but manages by NetVSP + PVOID SendBuffer; + UINT32 SendBufferSize; + UINT32 SendBufferGpadlHandle; + UINT32 SendSectionSize; + + // Receive buffer allocated by us but manages by NetVSP + PVOID ReceiveBuffer; + UINT32 ReceiveBufferSize; + UINT32 ReceiveBufferGpadlHandle; + UINT32 ReceiveSectionCount; + PNVSP_1_RECEIVE_BUFFER_SECTION ReceiveSections; + + // Used for NetVSP initialization protocol + HANDLE ChannelInitEvent; + NVSP_MESSAGE ChannelInitPacket; + + NVSP_MESSAGE RevokePacket; + //UCHAR HwMacAddr[HW_MACADDR_LEN]; + + // Holds rndis device info + void *Extension; +} NETVSC_DEVICE; + +#endif // _NETVSC_H_ --- /dev/null +++ b/drivers/staging/hv/RndisFilter.c @@ -0,0 +1,1162 @@ +/* + * + * Copyright (c) 2009, Microsoft Corporation. + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along with + * this program; if not, write to the Free Software Foundation, Inc., 59 Temple + * Place - Suite 330, Boston, MA 02111-1307 USA. + * + * Authors: + * Haiyang Zhang + * Hank Janssen + * + */ + + +#include "logging.h" + +#include "NetVscApi.h" +#include "RndisFilter.h" + +// +// Data types +// + +typedef struct _RNDIS_FILTER_DRIVER_OBJECT { + // The original driver + NETVSC_DRIVER_OBJECT InnerDriver; + +} RNDIS_FILTER_DRIVER_OBJECT; + +typedef enum { + RNDIS_DEV_UNINITIALIZED = 0, + RNDIS_DEV_INITIALIZING, + RNDIS_DEV_INITIALIZED, + RNDIS_DEV_DATAINITIALIZED, +} RNDIS_DEVICE_STATE; + +typedef struct _RNDIS_DEVICE { + NETVSC_DEVICE *NetDevice; + + RNDIS_DEVICE_STATE State; + UINT32 LinkStatus; + UINT32 NewRequestId; + + HANDLE RequestLock; + LIST_ENTRY RequestList; + + UCHAR HwMacAddr[HW_MACADDR_LEN]; +} RNDIS_DEVICE; + + +typedef struct _RNDIS_REQUEST { + LIST_ENTRY ListEntry; + HANDLE WaitEvent; + + // FIXME: We assumed a fixed size response here. If we do ever need to handle a bigger response, + // we can either define a max response message or add a response buffer variable above this field + RNDIS_MESSAGE ResponseMessage; + + // Simplify allocation by having a netvsc packet inline + NETVSC_PACKET Packet; + PAGE_BUFFER Buffer; + // FIXME: We assumed a fixed size request here. + RNDIS_MESSAGE RequestMessage; +} RNDIS_REQUEST; + + +typedef struct _RNDIS_FILTER_PACKET { + void *CompletionContext; + PFN_ON_SENDRECVCOMPLETION OnCompletion; + + RNDIS_MESSAGE Message; +} RNDIS_FILTER_PACKET; + +// +// Internal routines +// +static int +RndisFilterSendRequest( + RNDIS_DEVICE *Device, + RNDIS_REQUEST *Request + ); + +static void +RndisFilterReceiveResponse( + RNDIS_DEVICE *Device, + RNDIS_MESSAGE *Response + ); + +static void +RndisFilterReceiveIndicateStatus( + RNDIS_DEVICE *Device, + RNDIS_MESSAGE *Response + ); + +static void +RndisFilterReceiveData( + RNDIS_DEVICE *Device, + RNDIS_MESSAGE *Message, + NETVSC_PACKET *Packet + ); + +static int +RndisFilterOnReceive( + DEVICE_OBJECT *Device, + NETVSC_PACKET *Packet + ); + +static int +RndisFilterQueryDevice( + RNDIS_DEVICE *Device, + UINT32 Oid, + VOID *Result, + UINT32 *ResultSize + ); + +static inline int +RndisFilterQueryDeviceMac( + RNDIS_DEVICE *Device + ); + +static inline int +RndisFilterQueryDeviceLinkStatus( + RNDIS_DEVICE *Device + ); + +static int +RndisFilterSetPacketFilter( + RNDIS_DEVICE *Device, + UINT32 NewFilter + ); + +static int +RndisFilterInitDevice( + RNDIS_DEVICE *Device + ); + +static int +RndisFilterOpenDevice( + RNDIS_DEVICE *Device + ); + +static int +RndisFilterCloseDevice( + RNDIS_DEVICE *Device + ); + +static int +RndisFilterOnDeviceAdd( + DEVICE_OBJECT *Device, + void *AdditionalInfo + ); + +static int +RndisFilterOnDeviceRemove( + DEVICE_OBJECT *Device + ); + +static void +RndisFilterOnCleanup( + DRIVER_OBJECT *Driver + ); + +static int +RndisFilterOnOpen( + DEVICE_OBJECT *Device + ); + +static int +RndisFilterOnClose( + DEVICE_OBJECT *Device + ); + +static int +RndisFilterOnSend( + DEVICE_OBJECT *Device, + NETVSC_PACKET *Packet + ); + +static void +RndisFilterOnSendCompletion( + void *Context + ); + +static void +RndisFilterOnSendRequestCompletion( + void *Context + ); + +// +// Global var +// + +// The one and only +RNDIS_FILTER_DRIVER_OBJECT gRndisFilter; + +static inline RNDIS_DEVICE* GetRndisDevice(void) +{ + RNDIS_DEVICE *device; + + device = MemAllocZeroed(sizeof(RNDIS_DEVICE)); + if (!device) + { + return NULL; + } + + device->RequestLock = SpinlockCreate(); + if (!device->RequestLock) + { + MemFree(device); + return NULL; + } + + INITIALIZE_LIST_HEAD(&device->RequestList); + + device->State = RNDIS_DEV_UNINITIALIZED; + + return device; +} + +static inline void PutRndisDevice(RNDIS_DEVICE *Device) +{ + SpinlockClose(Device->RequestLock); + MemFree(Device); +} + +static inline RNDIS_REQUEST* GetRndisRequest(RNDIS_DEVICE *Device, UINT32 MessageType, UINT32 MessageLength) +{ + RNDIS_REQUEST *request; + RNDIS_MESSAGE *rndisMessage; + RNDIS_SET_REQUEST *set; + + request = MemAllocZeroed(sizeof(RNDIS_REQUEST)); + if (!request) + { + return NULL; + } + + request->WaitEvent = WaitEventCreate(); + if (!request->WaitEvent) + { + MemFree(request); + return NULL; + } + + rndisMessage = &request->RequestMessage; + rndisMessage->NdisMessageType = MessageType; + rndisMessage->MessageLength = MessageLength; + + // Set the request id. This field is always after the rndis header for request/response packet types so + // we just used the SetRequest as a template + set = &rndisMessage->Message.SetRequest; + set->RequestId = InterlockedIncrement((int*)&Device->NewRequestId); + + // Add to the request list + SpinlockAcquire(Device->RequestLock); + INSERT_TAIL_LIST(&Device->RequestList, &request->ListEntry); + SpinlockRelease(Device->RequestLock); + + return request; +} + +static inline void PutRndisRequest(RNDIS_DEVICE *Device, RNDIS_REQUEST *Request) +{ + SpinlockAcquire(Device->RequestLock); + REMOVE_ENTRY_LIST(&Request->ListEntry); + SpinlockRelease(Device->RequestLock); + + WaitEventClose(Request->WaitEvent); + MemFree(Request); +} + +static inline void DumpRndisMessage(RNDIS_MESSAGE *RndisMessage) +{ + switch (RndisMessage->NdisMessageType) + { + case REMOTE_NDIS_PACKET_MSG: + DPRINT_DBG(NETVSC, "REMOTE_NDIS_PACKET_MSG (len %u, data offset %u data len %u, # oob %u, oob offset %u, oob len %u, pkt offset %u, pkt len %u", + RndisMessage->MessageLength, + RndisMessage->Message.Packet.DataOffset, + RndisMessage->Message.Packet.DataLength, + RndisMessage->Message.Packet.NumOOBDataElements, + RndisMessage->Message.Packet.OOBDataOffset, + RndisMessage->Message.Packet.OOBDataLength, + RndisMessage->Message.Packet.PerPacketInfoOffset, + RndisMessage->Message.Packet.PerPacketInfoLength); + break; + + case REMOTE_NDIS_INITIALIZE_CMPLT: + DPRINT_DBG(NETVSC, "REMOTE_NDIS_INITIALIZE_CMPLT (len %u, id 0x%x, status 0x%x, major %d, minor %d, device flags %d, max xfer size 0x%x, max pkts %u, pkt aligned %u)", + RndisMessage->MessageLength, + RndisMessage->Message.InitializeComplete.RequestId, + RndisMessage->Message.InitializeComplete.Status, + RndisMessage->Message.InitializeComplete.MajorVersion, + RndisMessage->Message.InitializeComplete.MinorVersion, + RndisMessage->Message.InitializeComplete.DeviceFlags, + RndisMessage->Message.InitializeComplete.MaxTransferSize, + RndisMessage->Message.InitializeComplete.MaxPacketsPerMessage, + RndisMessage->Message.InitializeComplete.PacketAlignmentFactor); + break; + + case REMOTE_NDIS_QUERY_CMPLT: + DPRINT_DBG(NETVSC, "REMOTE_NDIS_QUERY_CMPLT (len %u, id 0x%x, status 0x%x, buf len %u, buf offset %u)", + RndisMessage->MessageLength, + RndisMessage->Message.QueryComplete.RequestId, + RndisMessage->Message.QueryComplete.Status, + RndisMessage->Message.QueryComplete.InformationBufferLength, + RndisMessage->Message.QueryComplete.InformationBufferOffset); + break; + + case REMOTE_NDIS_SET_CMPLT: + DPRINT_DBG(NETVSC, "REMOTE_NDIS_SET_CMPLT (len %u, id 0x%x, status 0x%x)", + RndisMessage->MessageLength, + RndisMessage->Message.SetComplete.RequestId, + RndisMessage->Message.SetComplete.Status); + break; + + case REMOTE_NDIS_INDICATE_STATUS_MSG: + DPRINT_DBG(NETVSC, "REMOTE_NDIS_INDICATE_STATUS_MSG (len %u, status 0x%x, buf len %u, buf offset %u)", + RndisMessage->MessageLength, + RndisMessage->Message.IndicateStatus.Status, + RndisMessage->Message.IndicateStatus.StatusBufferLength, + RndisMessage->Message.IndicateStatus.StatusBufferOffset); + break; + + default: + DPRINT_DBG(NETVSC, "0x%x (len %u)", + RndisMessage->NdisMessageType, + RndisMessage->MessageLength); + break; + } +} + +static int +RndisFilterSendRequest( + RNDIS_DEVICE *Device, + RNDIS_REQUEST *Request + ) +{ + int ret=0; + NETVSC_PACKET *packet; + + DPRINT_ENTER(NETVSC); + + // Setup the packet to send it + packet = &Request->Packet; + + packet->IsDataPacket = FALSE; + packet->TotalDataBufferLength = Request->RequestMessage.MessageLength; + packet->PageBufferCount = 1; + + packet->PageBuffers[0].Pfn = GetPhysicalAddress(&Request->RequestMessage) >> PAGE_SHIFT; + packet->PageBuffers[0].Length = Request->RequestMessage.MessageLength; + packet->PageBuffers[0].Offset = (ULONG_PTR)&Request->RequestMessage & (PAGE_SIZE -1); + + packet->Completion.Send.SendCompletionContext = Request;//packet; + packet->Completion.Send.OnSendCompletion = RndisFilterOnSendRequestCompletion; + packet->Completion.Send.SendCompletionTid = (ULONG_PTR)Device; + + ret = gRndisFilter.InnerDriver.OnSend(Device->NetDevice->Device, packet); + DPRINT_EXIT(NETVSC); + return ret; +} + + +static void +RndisFilterReceiveResponse( + RNDIS_DEVICE *Device, + RNDIS_MESSAGE *Response + ) +{ + LIST_ENTRY *anchor; + LIST_ENTRY *curr; + RNDIS_REQUEST *request=NULL; + BOOL found=FALSE; + + DPRINT_ENTER(NETVSC); + + SpinlockAcquire(Device->RequestLock); + ITERATE_LIST_ENTRIES(anchor, curr, &Device->RequestList) + { + request = CONTAINING_RECORD(curr, RNDIS_REQUEST, ListEntry); + + // All request/response message contains RequestId as the 1st field + if (request->RequestMessage.Message.InitializeRequest.RequestId == Response->Message.InitializeComplete.RequestId) + { + DPRINT_DBG(NETVSC, "found rndis request for this response (id 0x%x req type 0x%x res type 0x%x)", + request->RequestMessage.Message.InitializeRequest.RequestId, request->RequestMessage.NdisMessageType, Response->NdisMessageType); + + found = TRUE; + break; + } + } + SpinlockRelease(Device->RequestLock); + + if (found) + { + if (Response->MessageLength <= sizeof(RNDIS_MESSAGE)) + { + memcpy(&request->ResponseMessage, Response, Response->MessageLength); + } + else + { + DPRINT_ERR(NETVSC, "rndis response buffer overflow detected (size %u max %u)", Response->MessageLength, sizeof(RNDIS_FILTER_PACKET)); + + if (Response->NdisMessageType == REMOTE_NDIS_RESET_CMPLT) // does not have a request id field + { + request->ResponseMessage.Message.ResetComplete.Status = STATUS_BUFFER_OVERFLOW; + } + else + { + request->ResponseMessage.Message.InitializeComplete.Status = STATUS_BUFFER_OVERFLOW; + } + } + + WaitEventSet(request->WaitEvent); + } + else + { + DPRINT_ERR(NETVSC, "no rndis request found for this response (id 0x%x res type 0x%x)", + Response->Message.InitializeComplete.RequestId, Response->NdisMessageType); + } + + DPRINT_EXIT(NETVSC); +} + +static void +RndisFilterReceiveIndicateStatus( + RNDIS_DEVICE *Device, + RNDIS_MESSAGE *Response + ) +{ + RNDIS_INDICATE_STATUS *indicate = &Response->Message.IndicateStatus; + + if (indicate->Status == RNDIS_STATUS_MEDIA_CONNECT) + { + gRndisFilter.InnerDriver.OnLinkStatusChanged(Device->NetDevice->Device, 1); + } + else if (indicate->Status == RNDIS_STATUS_MEDIA_DISCONNECT) + { + gRndisFilter.InnerDriver.OnLinkStatusChanged(Device->NetDevice->Device, 0); + } + else + { + // TODO: + } +} + +static void +RndisFilterReceiveData( + RNDIS_DEVICE *Device, + RNDIS_MESSAGE *Message, + NETVSC_PACKET *Packet + ) +{ + RNDIS_PACKET *rndisPacket; + UINT32 dataOffset; + + DPRINT_ENTER(NETVSC); + + // empty ethernet frame ?? + ASSERT(Packet->PageBuffers[0].Length > RNDIS_MESSAGE_SIZE(RNDIS_PACKET)); + + rndisPacket = &Message->Message.Packet; + + // FIXME: Handle multiple rndis pkt msgs that maybe enclosed in this + // netvsc packet (ie TotalDataBufferLength != MessageLength) + + // Remove the rndis header and pass it back up the stack + dataOffset = RNDIS_HEADER_SIZE + rndisPacket->DataOffset; + + Packet->TotalDataBufferLength -= dataOffset; + Packet->PageBuffers[0].Offset += dataOffset; + Packet->PageBuffers[0].Length -= dataOffset; + + Packet->IsDataPacket = TRUE; + + gRndisFilter.InnerDriver.OnReceiveCallback(Device->NetDevice->Device, Packet); + + DPRINT_EXIT(NETVSC); +} + +static int +RndisFilterOnReceive( + DEVICE_OBJECT *Device, + NETVSC_PACKET *Packet + ) +{ + NETVSC_DEVICE *netDevice = (NETVSC_DEVICE*)Device->Extension; + RNDIS_DEVICE *rndisDevice; + RNDIS_MESSAGE rndisMessage; + RNDIS_MESSAGE *rndisHeader; + + DPRINT_ENTER(NETVSC); + + ASSERT(netDevice); + //Make sure the rndis device state is initialized + if (!netDevice->Extension) + { + DPRINT_ERR(NETVSC, "got rndis message but no rndis device...dropping this message!"); + DPRINT_EXIT(NETVSC); + return -1; + } + + rndisDevice = (RNDIS_DEVICE*)netDevice->Extension; + if (rndisDevice->State == RNDIS_DEV_UNINITIALIZED) + { + DPRINT_ERR(NETVSC, "got rndis message but rndis device uninitialized...dropping this message!"); + DPRINT_EXIT(NETVSC); + return -1; + } + + rndisHeader = (RNDIS_MESSAGE*)PageMapVirtualAddress(Packet->PageBuffers[0].Pfn); + + rndisHeader = (void*)((ULONG_PTR)rndisHeader + Packet->PageBuffers[0].Offset); + + // Make sure we got a valid rndis message + // FIXME: There seems to be a bug in set completion msg where its MessageLength is 16 bytes but + // the ByteCount field in the xfer page range shows 52 bytes +#if 0 + if ( Packet->TotalDataBufferLength != rndisHeader->MessageLength ) + { + PageUnmapVirtualAddress((void*)(ULONG_PTR)rndisHeader - Packet->PageBuffers[0].Offset); + + DPRINT_ERR(NETVSC, "invalid rndis message? (expected %u bytes got %u)...dropping this message!", + rndisHeader->MessageLength, Packet->TotalDataBufferLength); + DPRINT_EXIT(NETVSC); + return -1; + } +#endif + + if ((rndisHeader->NdisMessageType != REMOTE_NDIS_PACKET_MSG) && (rndisHeader->MessageLength > sizeof(RNDIS_MESSAGE))) + { + DPRINT_ERR(NETVSC, "incoming rndis message buffer overflow detected (got %u, max %u)...marking it an error!", + rndisHeader->MessageLength, sizeof(RNDIS_MESSAGE)); + } + + memcpy(&rndisMessage, rndisHeader, (rndisHeader->MessageLength > sizeof(RNDIS_MESSAGE))?sizeof(RNDIS_MESSAGE):rndisHeader->MessageLength); + + PageUnmapVirtualAddress((void*)(ULONG_PTR)rndisHeader - Packet->PageBuffers[0].Offset); + + DumpRndisMessage(&rndisMessage); + + switch (rndisMessage.NdisMessageType) + { + // data msg + case REMOTE_NDIS_PACKET_MSG: + RndisFilterReceiveData(rndisDevice, &rndisMessage, Packet); + break; + + // completion msgs + case REMOTE_NDIS_INITIALIZE_CMPLT: + case REMOTE_NDIS_QUERY_CMPLT: + case REMOTE_NDIS_SET_CMPLT: + //case REMOTE_NDIS_RESET_CMPLT: + //case REMOTE_NDIS_KEEPALIVE_CMPLT: + RndisFilterReceiveResponse(rndisDevice, &rndisMessage); + break; + + // notification msgs + case REMOTE_NDIS_INDICATE_STATUS_MSG: + RndisFilterReceiveIndicateStatus(rndisDevice, &rndisMessage); + break; + default: + DPRINT_ERR(NETVSC, "unhandled rndis message (type %u len %u)", rndisMessage.NdisMessageType, rndisMessage.MessageLength); + break; + } + + DPRINT_EXIT(NETVSC); + return 0; +} + + +static int +RndisFilterQueryDevice( + RNDIS_DEVICE *Device, + UINT32 Oid, + VOID *Result, + UINT32 *ResultSize + ) +{ + RNDIS_REQUEST *request; + UINT32 inresultSize = *ResultSize; + RNDIS_QUERY_REQUEST *query; + RNDIS_QUERY_COMPLETE *queryComplete; + int ret=0; + + DPRINT_ENTER(NETVSC); + + ASSERT(Result); + + *ResultSize = 0; + request = GetRndisRequest(Device, REMOTE_NDIS_QUERY_MSG, RNDIS_MESSAGE_SIZE(RNDIS_QUERY_REQUEST)); + if (!request) + { + ret = -1; + goto Cleanup; + } + + // Setup the rndis query + query = &request->RequestMessage.Message.QueryRequest; + query->Oid = Oid; + query->InformationBufferOffset = sizeof(RNDIS_QUERY_REQUEST); + query->InformationBufferLength = 0; + query->DeviceVcHandle = 0; + + ret = RndisFilterSendRequest(Device, request); + if (ret != 0) + { + goto Cleanup; + } + + WaitEventWait(request->WaitEvent); + + // Copy the response back + queryComplete = &request->ResponseMessage.Message.QueryComplete; + + if (queryComplete->InformationBufferLength > inresultSize) + { + ret = -1; + goto Cleanup; + } + + memcpy(Result, + (void*)((ULONG_PTR)queryComplete + queryComplete->InformationBufferOffset), + queryComplete->InformationBufferLength); + + *ResultSize = queryComplete->InformationBufferLength; + +Cleanup: + if (request) + { + PutRndisRequest(Device, request); + } + DPRINT_EXIT(NETVSC); + + return ret; +} + +static inline int +RndisFilterQueryDeviceMac( + RNDIS_DEVICE *Device + ) +{ + UINT32 size=HW_MACADDR_LEN; + + return RndisFilterQueryDevice(Device, + RNDIS_OID_802_3_PERMANENT_ADDRESS, + Device->HwMacAddr, + &size); +} + +static inline int +RndisFilterQueryDeviceLinkStatus( + RNDIS_DEVICE *Device + ) +{ + UINT32 size=sizeof(UINT32); + + return RndisFilterQueryDevice(Device, + RNDIS_OID_GEN_MEDIA_CONNECT_STATUS, + &Device->LinkStatus, + &size); +} + +static int +RndisFilterSetPacketFilter( + RNDIS_DEVICE *Device, + UINT32 NewFilter + ) +{ + RNDIS_REQUEST *request; + RNDIS_SET_REQUEST *set; + RNDIS_SET_COMPLETE *setComplete; + UINT32 status; + int ret; + + DPRINT_ENTER(NETVSC); + + ASSERT(RNDIS_MESSAGE_SIZE(RNDIS_SET_REQUEST) + sizeof(UINT32) <= sizeof(RNDIS_MESSAGE)); + + request = GetRndisRequest(Device, REMOTE_NDIS_SET_MSG, RNDIS_MESSAGE_SIZE(RNDIS_SET_REQUEST) + sizeof(UINT32)); + if (!request) + { + ret = -1; + goto Cleanup; + } + + // Setup the rndis set + set = &request->RequestMessage.Message.SetRequest; + set->Oid = RNDIS_OID_GEN_CURRENT_PACKET_FILTER; + set->InformationBufferLength = sizeof(UINT32); + set->InformationBufferOffset = sizeof(RNDIS_SET_REQUEST); + + memcpy((void*)(ULONG_PTR)set + sizeof(RNDIS_SET_REQUEST), &NewFilter, sizeof(UINT32)); + + ret = RndisFilterSendRequest(Device, request); + if (ret != 0) + { + goto Cleanup; + } + + ret = WaitEventWaitEx(request->WaitEvent, 2000/*2sec*/); + if (!ret) + { + ret = -1; + DPRINT_ERR(NETVSC, "timeout before we got a set response..."); + // We cant deallocate the request since we may still receive a send completion for it. + goto Exit; + } + else + { + if (ret > 0) + { + ret = 0; + } + setComplete = &request->ResponseMessage.Message.SetComplete; + status = setComplete->Status; + } + +Cleanup: + if (request) + { + PutRndisRequest(Device, request); + } +Exit: + DPRINT_EXIT(NETVSC); + + return ret; +} + +int +RndisFilterInit( + NETVSC_DRIVER_OBJECT *Driver + ) +{ + DPRINT_ENTER(NETVSC); + + DPRINT_DBG(NETVSC, "sizeof(RNDIS_FILTER_PACKET) == %d", sizeof(RNDIS_FILTER_PACKET)); + + Driver->RequestExtSize = sizeof(RNDIS_FILTER_PACKET); + Driver->AdditionalRequestPageBufferCount = 1; // For rndis header + + //Driver->Context = rndisDriver; + + memset(&gRndisFilter, 0, sizeof(RNDIS_FILTER_DRIVER_OBJECT)); + + /*rndisDriver->Driver = Driver; + + ASSERT(Driver->OnLinkStatusChanged); + rndisDriver->OnLinkStatusChanged = Driver->OnLinkStatusChanged;*/ + + // Save the original dispatch handlers before we override it + gRndisFilter.InnerDriver.Base.OnDeviceAdd = Driver->Base.OnDeviceAdd; + gRndisFilter.InnerDriver.Base.OnDeviceRemove = Driver->Base.OnDeviceRemove; + gRndisFilter.InnerDriver.Base.OnCleanup = Driver->Base.OnCleanup; + + ASSERT(Driver->OnSend); + ASSERT(Driver->OnReceiveCallback); + gRndisFilter.InnerDriver.OnSend = Driver->OnSend; + gRndisFilter.InnerDriver.OnReceiveCallback = Driver->OnReceiveCallback; + gRndisFilter.InnerDriver.OnLinkStatusChanged = Driver->OnLinkStatusChanged; + + // Override + Driver->Base.OnDeviceAdd = RndisFilterOnDeviceAdd; + Driver->Base.OnDeviceRemove = RndisFilterOnDeviceRemove; + Driver->Base.OnCleanup = RndisFilterOnCleanup; + Driver->OnSend = RndisFilterOnSend; + Driver->OnOpen = RndisFilterOnOpen; + Driver->OnClose = RndisFilterOnClose; + //Driver->QueryLinkStatus = RndisFilterQueryDeviceLinkStatus; + Driver->OnReceiveCallback = RndisFilterOnReceive; + + DPRINT_EXIT(NETVSC); + + return 0; +} + +static int +RndisFilterInitDevice( + RNDIS_DEVICE *Device + ) +{ + RNDIS_REQUEST *request; + RNDIS_INITIALIZE_REQUEST *init; + RNDIS_INITIALIZE_COMPLETE *initComplete; + UINT32 status; + int ret; + + DPRINT_ENTER(NETVSC); + + request = GetRndisRequest(Device, REMOTE_NDIS_INITIALIZE_MSG, RNDIS_MESSAGE_SIZE(RNDIS_INITIALIZE_REQUEST)); + if (!request) + { + ret = -1; + goto Cleanup; + } + + // Setup the rndis set + init = &request->RequestMessage.Message.InitializeRequest; + init->MajorVersion = RNDIS_MAJOR_VERSION; + init->MinorVersion = RNDIS_MINOR_VERSION; + init->MaxTransferSize = 2048; // FIXME: Use 1536 - rounded ethernet frame size + + Device->State = RNDIS_DEV_INITIALIZING; + + ret = RndisFilterSendRequest(Device, request); + if (ret != 0) + { + Device->State = RNDIS_DEV_UNINITIALIZED; + goto Cleanup; + } + + WaitEventWait(request->WaitEvent); + + initComplete = &request->ResponseMessage.Message.InitializeComplete; + status = initComplete->Status; + if (status == RNDIS_STATUS_SUCCESS) + { + Device->State = RNDIS_DEV_INITIALIZED; + ret = 0; + } + else + { + Device->State = RNDIS_DEV_UNINITIALIZED; + ret = -1; + } + +Cleanup: + if (request) + { + PutRndisRequest(Device, request); + } + DPRINT_EXIT(NETVSC); + + return ret; +} + +static void +RndisFilterHaltDevice( + RNDIS_DEVICE *Device + ) +{ + RNDIS_REQUEST *request; + RNDIS_HALT_REQUEST *halt; + + DPRINT_ENTER(NETVSC); + + // Attempt to do a rndis device halt + request = GetRndisRequest(Device, REMOTE_NDIS_HALT_MSG, RNDIS_MESSAGE_SIZE(RNDIS_HALT_REQUEST)); + if (!request) + { + goto Cleanup; + } + + // Setup the rndis set + halt = &request->RequestMessage.Message.HaltRequest; + halt->RequestId = InterlockedIncrement((int*)&Device->NewRequestId); + + // Ignore return since this msg is optional. + RndisFilterSendRequest(Device, request); + + Device->State = RNDIS_DEV_UNINITIALIZED; + +Cleanup: + if (request) + { + PutRndisRequest(Device, request); + } + DPRINT_EXIT(NETVSC); + return; +} + + +static int +RndisFilterOpenDevice( + RNDIS_DEVICE *Device + ) +{ + int ret=0; + + DPRINT_ENTER(NETVSC); + + if (Device->State != RNDIS_DEV_INITIALIZED) + return 0; + + ret = RndisFilterSetPacketFilter(Device, NDIS_PACKET_TYPE_BROADCAST|NDIS_PACKET_TYPE_DIRECTED); + if (ret == 0) + { + Device->State = RNDIS_DEV_DATAINITIALIZED; + } + + DPRINT_EXIT(NETVSC); + return ret; +} + +static int +RndisFilterCloseDevice( + RNDIS_DEVICE *Device + ) +{ + int ret; + + DPRINT_ENTER(NETVSC); + + if (Device->State != RNDIS_DEV_DATAINITIALIZED) + return 0; + + ret = RndisFilterSetPacketFilter(Device, 0); + if (ret == 0) + { + Device->State = RNDIS_DEV_INITIALIZED; + } + + DPRINT_EXIT(NETVSC); + + return ret; +} + + +int +RndisFilterOnDeviceAdd( + DEVICE_OBJECT *Device, + void *AdditionalInfo + ) +{ + int ret; + NETVSC_DEVICE *netDevice; + RNDIS_DEVICE *rndisDevice; + NETVSC_DEVICE_INFO *deviceInfo = (NETVSC_DEVICE_INFO*)AdditionalInfo; + + DPRINT_ENTER(NETVSC); + + //rndisDevice = MemAlloc(sizeof(RNDIS_DEVICE)); + rndisDevice = GetRndisDevice(); + if (!rndisDevice) + { + DPRINT_EXIT(NETVSC); + return -1; + } + + DPRINT_DBG(NETVSC, "rndis device object allocated - %p", rndisDevice); + + // Let the inner driver handle this first to create the netvsc channel + // NOTE! Once the channel is created, we may get a receive callback + // (RndisFilterOnReceive()) before this call is completed + ret = gRndisFilter.InnerDriver.Base.OnDeviceAdd(Device, AdditionalInfo); + if (ret != 0) + { + PutRndisDevice(rndisDevice); + DPRINT_EXIT(NETVSC); + return ret; + } + + // + // Initialize the rndis device + // + netDevice = (NETVSC_DEVICE*)Device->Extension; + ASSERT(netDevice); + ASSERT(netDevice->Device); + + netDevice->Extension = rndisDevice; + rndisDevice->NetDevice = netDevice; + + // Send the rndis initialization message + ret = RndisFilterInitDevice(rndisDevice); + if (ret != 0) + { + // TODO: If rndis init failed, we will need to shut down the channel + } + + // Get the mac address + ret = RndisFilterQueryDeviceMac(rndisDevice); + if (ret != 0) + { + // TODO: shutdown rndis device and the channel + } + + DPRINT_INFO(NETVSC, "Device 0x%p mac addr %02x%02x%02x%02x%02x%02x", + rndisDevice, + rndisDevice->HwMacAddr[0], + rndisDevice->HwMacAddr[1], + rndisDevice->HwMacAddr[2], + rndisDevice->HwMacAddr[3], + rndisDevice->HwMacAddr[4], + rndisDevice->HwMacAddr[5]); + + memcpy(deviceInfo->MacAddr, rndisDevice->HwMacAddr, HW_MACADDR_LEN); + + RndisFilterQueryDeviceLinkStatus(rndisDevice); + + deviceInfo->LinkState = rndisDevice->LinkStatus; + DPRINT_INFO(NETVSC, "Device 0x%p link state %s", rndisDevice, ((deviceInfo->LinkState)?("down"):("up"))); + + DPRINT_EXIT(NETVSC); + + return ret; +} + + +static int +RndisFilterOnDeviceRemove( + DEVICE_OBJECT *Device + ) +{ + NETVSC_DEVICE *netDevice = (NETVSC_DEVICE*)Device->Extension; + RNDIS_DEVICE *rndisDevice = (RNDIS_DEVICE*)netDevice->Extension; + + DPRINT_ENTER(NETVSC); + + // Halt and release the rndis device + RndisFilterHaltDevice(rndisDevice); + + PutRndisDevice(rndisDevice); + netDevice->Extension = NULL; + + // Pass control to inner driver to remove the device + gRndisFilter.InnerDriver.Base.OnDeviceRemove(Device); + + DPRINT_EXIT(NETVSC); + + return 0; +} + + +static void +RndisFilterOnCleanup( + DRIVER_OBJECT *Driver + ) +{ + DPRINT_ENTER(NETVSC); + + DPRINT_EXIT(NETVSC); +} + +static int +RndisFilterOnOpen( + DEVICE_OBJECT *Device + ) +{ + int ret; + NETVSC_DEVICE *netDevice = (NETVSC_DEVICE*)Device->Extension; + + DPRINT_ENTER(NETVSC); + + ASSERT(netDevice); + ret = RndisFilterOpenDevice((RNDIS_DEVICE*)netDevice->Extension); + + DPRINT_EXIT(NETVSC); + + return ret; +} + +static int +RndisFilterOnClose( + DEVICE_OBJECT *Device + ) +{ + int ret; + NETVSC_DEVICE *netDevice = (NETVSC_DEVICE*)Device->Extension; + + DPRINT_ENTER(NETVSC); + + ASSERT(netDevice); + ret = RndisFilterCloseDevice((RNDIS_DEVICE*)netDevice->Extension); + + DPRINT_EXIT(NETVSC); + + return ret; +} + + +static int +RndisFilterOnSend( + DEVICE_OBJECT *Device, + NETVSC_PACKET *Packet + ) +{ + int ret=0; + RNDIS_FILTER_PACKET *filterPacket; + RNDIS_MESSAGE *rndisMessage; + RNDIS_PACKET *rndisPacket; + UINT32 rndisMessageSize; + + DPRINT_ENTER(NETVSC); + + // Add the rndis header + filterPacket = (RNDIS_FILTER_PACKET*)Packet->Extension; + ASSERT(filterPacket); + + memset(filterPacket, 0, sizeof(RNDIS_FILTER_PACKET)); + + rndisMessage = &filterPacket->Message; + rndisMessageSize = RNDIS_MESSAGE_SIZE(RNDIS_PACKET); + + rndisMessage->NdisMessageType = REMOTE_NDIS_PACKET_MSG; + rndisMessage->MessageLength = Packet->TotalDataBufferLength + rndisMessageSize; + + rndisPacket = &rndisMessage->Message.Packet; + rndisPacket->DataOffset = sizeof(RNDIS_PACKET); + rndisPacket->DataLength = Packet->TotalDataBufferLength; + + Packet->IsDataPacket = TRUE; + Packet->PageBuffers[0].Pfn = GetPhysicalAddress(rndisMessage) >> PAGE_SHIFT; + Packet->PageBuffers[0].Offset = (ULONG_PTR)rndisMessage & (PAGE_SIZE-1); + Packet->PageBuffers[0].Length = rndisMessageSize; + + // Save the packet send completion and context + filterPacket->OnCompletion = Packet->Completion.Send.OnSendCompletion; + filterPacket->CompletionContext = Packet->Completion.Send.SendCompletionContext; + + // Use ours + Packet->Completion.Send.OnSendCompletion = RndisFilterOnSendCompletion; + Packet->Completion.Send.SendCompletionContext = filterPacket; + + ret = gRndisFilter.InnerDriver.OnSend(Device, Packet); + if (ret != 0) + { + // Reset the completion to originals to allow retries from above + Packet->Completion.Send.OnSendCompletion = filterPacket->OnCompletion; + Packet->Completion.Send.SendCompletionContext = filterPacket->CompletionContext; + } + + DPRINT_EXIT(NETVSC); + + return ret; +} + +static void +RndisFilterOnSendCompletion( + void *Context) +{ + RNDIS_FILTER_PACKET *filterPacket = (RNDIS_FILTER_PACKET *)Context; + + DPRINT_ENTER(NETVSC); + + // Pass it back to the original handler + filterPacket->OnCompletion(filterPacket->CompletionContext); + + DPRINT_EXIT(NETVSC); +} + + +static void +RndisFilterOnSendRequestCompletion( + void *Context + ) +{ + DPRINT_ENTER(NETVSC); + + // Noop + DPRINT_EXIT(NETVSC); +} --- /dev/null +++ b/drivers/staging/hv/RndisFilter.h @@ -0,0 +1,61 @@ +/* + * + * Copyright (c) 2009, Microsoft Corporation. + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along with + * this program; if not, write to the Free Software Foundation, Inc., 59 Temple + * Place - Suite 330, Boston, MA 02111-1307 USA. + * + * Authors: + * Haiyang Zhang + * Hank Janssen + * + */ + + +#ifndef _RNDISFILTER_H_ +#define _RNDISFILTER_H_ + +#define __struct_bcount(x) + +#include "osd.h" +#include "NetVsc.h" + +#include "rndis.h" + +#define RNDIS_HEADER_SIZE (sizeof(RNDIS_MESSAGE) - sizeof(RNDIS_MESSAGE_CONTAINER)) + +#define NDIS_PACKET_TYPE_DIRECTED 0x00000001 +#define NDIS_PACKET_TYPE_MULTICAST 0x00000002 +#define NDIS_PACKET_TYPE_ALL_MULTICAST 0x00000004 +#define NDIS_PACKET_TYPE_BROADCAST 0x00000008 +#define NDIS_PACKET_TYPE_SOURCE_ROUTING 0x00000010 +#define NDIS_PACKET_TYPE_PROMISCUOUS 0x00000020 +#define NDIS_PACKET_TYPE_SMT 0x00000040 +#define NDIS_PACKET_TYPE_ALL_LOCAL 0x00000080 +#define NDIS_PACKET_TYPE_GROUP 0x00000100 +#define NDIS_PACKET_TYPE_ALL_FUNCTIONAL 0x00000200 +#define NDIS_PACKET_TYPE_FUNCTIONAL 0x00000400 +#define NDIS_PACKET_TYPE_MAC_FRAME 0x00000800 + + + +// +// Interface +// +int +RndisFilterInit( + NETVSC_DRIVER_OBJECT *Driver + ); + + +#endif // _RNDISFILTER_H_