Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

.net - Implementing extension method WebRequest.GetResponseAsync with support for CancellationToken

The idea here is simple, but the implementation has some interesting nuances. This is the signature of the extension method I would like to implement in .NET 4.

public static Task<WebResponse> GetResponseAsync(this WebRequest request, CancellationToken token);

Here is my initial implementation. According to the HttpWebRequest documentation, the Timeout property has no effect on asynchronous requests made with BeginGetResponse, so the request has to be aborted manually when the timeout elapses. In addition to that, I want to properly call request.Abort() if cancellation is requested via the CancellationToken.

public static Task<WebResponse> GetResponseAsync(this WebRequest request, CancellationToken token)
{
    if (request == null)
        throw new ArgumentNullException("request");

    return Task.Factory.FromAsync<WebRequest, CancellationToken, WebResponse>(BeginGetResponse, request.EndGetResponse, request, token, null);
}

private static IAsyncResult BeginGetResponse(WebRequest request, CancellationToken token, AsyncCallback callback, object state)
{
    IAsyncResult asyncResult = request.BeginGetResponse(callback, state);
    if (!asyncResult.IsCompleted)
    {
        // Timeout is not honored for BeginGetResponse, so abort the request manually when it elapses.
        if (request.Timeout != Timeout.Infinite)
            ThreadPool.RegisterWaitForSingleObject(asyncResult.AsyncWaitHandle, WebRequestTimeoutCallback, request, request.Timeout, true);
        // Abort the request when cancellation is requested through the token.
        if (token != CancellationToken.None)
            ThreadPool.RegisterWaitForSingleObject(token.WaitHandle, WebRequestCancelledCallback, Tuple.Create(request, token), Timeout.Infinite, true);
    }

    return asyncResult;
}

private static void WebRequestTimeoutCallback(object state, bool timedOut)
{
    if (timedOut)
    {
        WebRequest request = state as WebRequest;
        if (request != null)
            request.Abort();
    }
}

private static void WebRequestCancelledCallback(object state, bool timedOut)
{
    Tuple<WebRequest, CancellationToken> data = state as Tuple<WebRequest, CancellationToken>;
    if (data != null && data.Item2.IsCancellationRequested)
    {
        data.Item1.Abort();
    }
}

My question is simple yet challenging. Will this implementation actually behave as expected when used with the TPL?


1 Reply


Will this implementation actually behave as expected when used with the TPL?

No.

  1. When cancellation is requested, the Task<T> result is not flagged as canceled. Abort() makes EndGetResponse throw a WebException, so the task ends in the Faulted state instead of the Canceled state.
  2. In the event of a timeout, the WebException contained in the AggregateException reported by Task.Exception will have the status WebExceptionStatus.RequestCanceled. It should instead be WebExceptionStatus.Timeout.
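To make point 1 concrete without touching the network, here is a minimal sketch. EndAbortedRequest is a hypothetical stand-in for EndGetResponse after Abort(): it throws a WebException with status RequestCanceled. Passed to Task.Factory.FromAsync, that exception simply faults the task; a callback that completes a TaskCompletionSource can instead check the token and call TrySetCanceled.

```csharp
using System;
using System.Net;
using System.Threading;
using System.Threading.Tasks;

public static class CancellationStatusDemo
{
    // Hypothetical stand-in for EndGetResponse on an aborted request:
    // it throws the same WebException status an aborted request reports.
    static WebResponse EndAbortedRequest(IAsyncResult result)
    {
        throw new WebException("The request was aborted.", WebExceptionStatus.RequestCanceled);
    }

    // FromAsync only sees the exception thrown by the End method, so the task faults.
    public static TaskStatus ViaFromAsync()
    {
        var alreadyDone = new TaskCompletionSource<object>();
        alreadyDone.SetResult(null); // a Task is also an IAsyncResult
        Task<WebResponse> task = Task.Factory.FromAsync<WebResponse>(alreadyDone.Task, EndAbortedRequest);
        try { task.Wait(); } catch (AggregateException) { }
        return task.Status;
    }

    // A TaskCompletionSource lets the callback translate the abort into cancellation.
    public static TaskStatus ViaCompletionSource(CancellationToken token)
    {
        var completionSource = new TaskCompletionSource<WebResponse>();
        try
        {
            completionSource.TrySetResult(EndAbortedRequest(null));
        }
        catch (WebException ex)
        {
            if (token.IsCancellationRequested)
                completionSource.TrySetCanceled();
            else
                completionSource.TrySetException(ex);
        }
        try { completionSource.Task.Wait(); } catch (AggregateException) { }
        return completionSource.Task.Status;
    }
}
```

With a cancelled token (for example new CancellationToken(true)), ViaCompletionSource returns TaskStatus.Canceled; with CancellationToken.None it returns TaskStatus.Faulted, the same state ViaFromAsync always produces.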

I would recommend implementing this with TaskCompletionSource<T> instead. It lets you write the code without defining your own APM-style (Begin/End) methods:

using System;
using System.Net;
using System.Threading;
using System.Threading.Tasks;

public static class WebRequestExtensions
{
    public static Task<WebResponse> GetResponseAsync(this WebRequest request, CancellationToken token)
    {
        if (request == null)
            throw new ArgumentNullException("request");

        // Set by the timeout callback so the completion callback can report a
        // timeout instead of the RequestCanceled status that Abort() produces.
        bool timeout = false;
        TaskCompletionSource<WebResponse> completionSource = new TaskCompletionSource<WebResponse>();

        // Runs when the request completes, fails, or is aborted, and moves the
        // task into the matching state (RanToCompletion, Faulted, or Canceled).
        AsyncCallback completedCallback =
            result =>
            {
                try
                {
                    completionSource.TrySetResult(request.EndGetResponse(result));
                }
                catch (WebException ex)
                {
                    if (timeout)
                        completionSource.TrySetException(new WebException("No response was received during the time-out period for a request.", WebExceptionStatus.Timeout));
                    else if (token.IsCancellationRequested)
                        completionSource.TrySetCanceled();
                    else
                        completionSource.TrySetException(ex);
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                }
            };

        IAsyncResult asyncResult = request.BeginGetResponse(completedCallback, null);
        if (!asyncResult.IsCompleted)
        {
            // Timeout is not honored for BeginGetResponse, so abort the request
            // ourselves if it has not completed within request.Timeout milliseconds.
            if (request.Timeout != Timeout.Infinite)
            {
                WaitOrTimerCallback timedOutCallback =
                    (object state, bool timedOut) =>
                    {
                        if (timedOut)
                        {
                            timeout = true;
                            request.Abort();
                        }
                    };

                ThreadPool.RegisterWaitForSingleObject(asyncResult.AsyncWaitHandle, timedOutCallback, null, request.Timeout, true);
            }

            // Abort the request when cancellation is requested through the token.
            // Note: this wait stays registered until the token is cancelled.
            if (token.CanBeCanceled)
            {
                WaitOrTimerCallback cancelledCallback =
                    (object state, bool timedOut) =>
                    {
                        if (token.IsCancellationRequested)
                            request.Abort();
                    };

                ThreadPool.RegisterWaitForSingleObject(token.WaitHandle, cancelledCallback, null, Timeout.Infinite, true);
            }
        }

        return completionSource.Task;
    }
}

The advantage here is that your Task<T> result behaves fully as expected: it is flagged as canceled when the token is cancelled, and on a timeout it raises the same WebException (with status Timeout) as the synchronous GetResponse would. This also avoids the overhead of Task.Factory.FromAsync, since you are already doing most of the difficult work it would otherwise handle.
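From the caller's side, these outcomes can be told apart by inspecting the finished task. Below is a sketch of a hypothetical helper (ResponseOutcome.Describe is an illustrative name, not part of any library) that a .NET 4 caller could use from a ContinueWith continuation, since async/await is not available there.

```csharp
using System;
using System.Net;
using System.Threading.Tasks;

public static class ResponseOutcome
{
    // Hypothetical helper: classifies a finished Task<WebResponse>, such as one
    // returned by GetResponseAsync(token), so a timeout can be told apart from
    // cancellation and from other failures.
    public static string Describe(Task<WebResponse> task)
    {
        if (task.IsCanceled)
            return "canceled";

        if (task.IsFaulted)
        {
            // Reading task.Exception also marks the fault as observed.
            WebException web = task.Exception.InnerException as WebException;
            if (web != null && web.Status == WebExceptionStatus.Timeout)
                return "timed out";
            return "failed";
        }

        return "completed";
    }
}
```

For example: request.GetResponseAsync(token).ContinueWith(t => Console.WriteLine(ResponseOutcome.Describe(t))). When the outcome is "completed", the caller still owns t.Result and should close it.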


Addendum by 280Z28

Here are MSTest unit tests showing that the method above behaves correctly. Note that they send real requests to google.com, so they need network access.

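using System;
using System.Collections.ObjectModel;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]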
public class AsyncWebRequestTests
{
    [TestMethod]
    public void TestAsyncWebRequest()
    {
        Uri uri = new Uri("http://google.com");
        WebRequest request = HttpWebRequest.Create(uri);
        Task<WebResponse> response = request.GetResponseAsync(CancellationToken.None);
        response.Wait();
    }

    [TestMethod]
    public void TestAsyncWebRequestTimeout()
    {
        Uri uri = new Uri("http://google.com");
        WebRequest request = HttpWebRequest.Create(uri);
        request.Timeout = 0; // the timeout elapses immediately, forcing the timeout path
        Task<WebResponse> response = request.GetResponseAsync(CancellationToken.None);
        try
        {
            response.Wait();
            Assert.Fail("Expected an exception");
        }
        catch (AggregateException exception)
        {
            Assert.AreEqual(TaskStatus.Faulted, response.Status);

            ReadOnlyCollection<Exception> exceptions = exception.InnerExceptions;
            Assert.AreEqual(1, exceptions.Count);
            Assert.IsInstanceOfType(exceptions[0], typeof(WebException));

            WebException webException = (WebException)exceptions[0];
            Assert.AreEqual(WebExceptionStatus.Timeout, webException.Status);
        }
    }

    [TestMethod]
    public void TestAsyncWebRequestCancellation()
    {
        Uri uri = new Uri("http://google.com");
        WebRequest request = HttpWebRequest.Create(uri);
        CancellationTokenSource cancellationTokenSource = new CancellationTokenSource();
        Task<WebResponse> response = request.GetResponseAsync(cancellationTokenSource.Token);
        cancellationTokenSource.Cancel();
        try
        {
            response.Wait();
            Assert.Fail("Expected an exception");
        }
        catch (AggregateException exception)
        {
            Assert.AreEqual(TaskStatus.Canceled, response.Status);

            ReadOnlyCollection<Exception> exceptions = exception.InnerExceptions;
            Assert.AreEqual(1, exceptions.Count);
            Assert.IsInstanceOfType(exceptions[0], typeof(OperationCanceledException));
        }
    }

    [TestMethod]
    public void TestAsyncWebRequestError()
    {
        Uri uri = new Uri("http://google.com/fail");
        WebRequest request = HttpWebRequest.Create(uri);
        Task<WebResponse> response = request.GetResponseAsync(CancellationToken.None);
        try
        {
            response.Wait();
            Assert.Fail("Expected an exception");
        }
        catch (AggregateException exception)
        {
            Assert.AreEqual(TaskStatus.Faulted, response.Status);

            ReadOnlyCollection<Exception> exceptions = exception.InnerExceptions;
            Assert.AreEqual(1, exceptions.Count);
            Assert.IsInstanceOfType(exceptions[0], typeof(WebException));

            WebException webException = (WebException)exceptions[0];
            Assert.AreEqual(HttpStatusCode.NotFound, ((HttpWebResponse)webException.Response).StatusCode);
        }
    }
}
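One caveat about the answer's implementation: the wait registered on token.WaitHandle with Timeout.Infinite is never unregistered, so with a long-lived token it stays active (and keeps the request reachable) until the token is cancelled. The sketch below is a variation, not part of the original answer, that swaps that wait for CancellationToken.Register and releases both registrations in a continuation once the task finishes. It also includes a hypothetical HangingWebRequest test double, so the canceled and timeout paths can be exercised without network access; both GetResponseWithRegistrationAsync and HangingWebRequest are illustrative names.

```csharp
using System;
using System.Net;
using System.Threading;
using System.Threading.Tasks;

public static class WebRequestRegistrationExtensions
{
    // Hypothetical name, chosen so it does not clash with GetResponseAsync.
    public static Task<WebResponse> GetResponseWithRegistrationAsync(this WebRequest request, CancellationToken token)
    {
        if (request == null)
            throw new ArgumentNullException("request");

        var completionSource = new TaskCompletionSource<WebResponse>();
        int timedOut = 0; // set to 1 by the timeout callback before it aborts

        AsyncCallback completedCallback = result =>
        {
            try
            {
                completionSource.TrySetResult(request.EndGetResponse(result));
            }
            catch (WebException ex)
            {
                if (Interlocked.CompareExchange(ref timedOut, 0, 0) == 1) // atomic read
                    completionSource.TrySetException(new WebException("No response was received during the time-out period for a request.", WebExceptionStatus.Timeout));
                else if (token.IsCancellationRequested)
                    completionSource.TrySetCanceled();
                else
                    completionSource.TrySetException(ex);
            }
            catch (Exception ex)
            {
                completionSource.TrySetException(ex);
            }
        };

        IAsyncResult asyncResult = request.BeginGetResponse(completedCallback, null);
        if (!asyncResult.IsCompleted)
        {
            RegisteredWaitHandle timeoutWait = null;
            if (request.Timeout != Timeout.Infinite)
            {
                WaitOrTimerCallback onTimeout = (state, expired) =>
                {
                    if (!expired)
                        return;
                    Interlocked.Exchange(ref timedOut, 1);
                    request.Abort();
                };
                timeoutWait = ThreadPool.RegisterWaitForSingleObject(asyncResult.AsyncWaitHandle, onTimeout, null, request.Timeout, true);
            }

            // Register() calls Abort() right away if the token is already cancelled.
            CancellationTokenRegistration registration = token.Register(() => request.Abort());

            // Release both registrations once the task finishes, whatever the outcome.
            completionSource.Task.ContinueWith(_ =>
            {
                registration.Dispose();
                if (timeoutWait != null)
                    timeoutWait.Unregister(null);
            });
        }

        return completionSource.Task;
    }
}

// Hypothetical test double: it never receives a response and completes only
// when Abort() is called, reporting RequestCanceled like an aborted request.
public sealed class HangingWebRequest : WebRequest
{
    private readonly TaskCompletionSource<object> pending = new TaskCompletionSource<object>();
    private AsyncCallback callback;
    public HangingWebRequest() { Timeout = System.Threading.Timeout.Infinite; }
    public override int Timeout { get; set; }

    public override IAsyncResult BeginGetResponse(AsyncCallback callback, object state)
    {
        this.callback = callback;
        return pending.Task; // a Task is also an IAsyncResult
    }

    public override WebResponse EndGetResponse(IAsyncResult asyncResult)
    {
        throw new WebException("The request was aborted.", WebExceptionStatus.RequestCanceled);
    }

    public override void Abort()
    {
        // TrySetResult ensures the callback runs once even if timeout and cancel race.
        if (pending.TrySetResult(null) && callback != null)
            callback(pending.Task);
    }
}
```

With this double, cancelling the token right after the call should leave the task in TaskStatus.Canceled, and setting Timeout = 50 should produce a Faulted task whose WebException has status Timeout.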
