The Gaudi Framework  master (181af51f)
Loading...
Searching...
No Matches
CUDAStream.cpp
Go to the documentation of this file.
1/***********************************************************************************\
2* (c) Copyright 2023-2024 CERN for the benefit of the LHCb and ATLAS collaborations *
3* *
4* This software is distributed under the terms of the Apache version 2 licence, *
5* copied verbatim in the file "LICENSE". *
6* *
7* In applying this licence, CERN does not waive the privileges and immunities *
8* granted to it by virtue of its status as an Intergovernmental Organization *
9* or submit itself to any jurisdiction. *
10\***********************************************************************************/
11
12// Gaudi
15
16// CUDA
17#ifndef __CUDACC__
18# include <cuda_runtime.h>
19#endif
20
21// Others
22#include <boost/fiber/cuda/waitfor.hpp>
23#include <boost/fiber/mutex.hpp>
24#include <boost/fiber/recursive_mutex.hpp>
25
26// Standard Library
27#include <cstdio>
28#include <deque>
29#include <format>
30#include <mutex>
31#include <string>
32
33namespace Gaudi::CUDA {
34 namespace {
35 class StreamList {
36 using Mutex_t = std::recursive_mutex;
37 using Stream_t = cudaStream_t;
38
39 private:
40 std::deque<Stream_t> queue;
41 Mutex_t queue_mtx;
42
43 public:
45 void push( const Stream_t& s ) {
46 std::unique_lock lck( queue_mtx );
47 queue.push_back( s );
48 }
49
52 bool pop( Stream_t& s ) {
53 std::unique_lock lck( queue_mtx );
54 if ( queue.empty() ) { return false; }
55 s = queue.front();
56 queue.pop_front();
57 return true;
58 }
59
60 ~StreamList() {
61 Stream_t s;
62 while ( pop( s ) ) {
63 cudaStreamDestroy( s );
64 s = nullptr;
65 }
66 }
67 };
68 StreamList available_streams{};
/// Build a one-line description of a CUDA runtime error.
///
/// @param err  CUDA error code to describe
/// @param file source file name to report
/// @param line source line number to report
/// @return text of the form "Encountered CUDA error <name> [<code>]: <description> on <file>:<line>"
std::string err_fmt( cudaError_t err, std::string file, int line ) {
  const char* const symbolic    = cudaGetErrorName( err );
  const char* const description = cudaGetErrorString( err );
  const int         numeric     = int( err );
  return std::format( "Encountered CUDA error {} [{}]: {} on {}:{}", symbolic, numeric, description, file, line );
}
76 } // namespace
// NOTE(review): the signature line of this definition is missing from this
// listing; the member-initializer list and the use of `parent` indicate it is
// the Stream constructor taking the parent algorithm.
//
// Initialises the members, then obtains a CUDA stream: one is reused from
// available_streams when the pool has any, otherwise a new one is created.
// Throws GaudiException ("CUDAStreamException", StatusCode::FAILURE) if
// creating or synchronising the new stream fails.
78 : m_stream( nullptr ), m_parent( parent ), m_dependents( 0 ) {
// pop() returns false when the pool is empty; only then create a stream.
79 if ( !available_streams.pop( m_stream ) ) {
80 cudaError_t err = cudaStreamCreate( &m_stream );
81 if ( err != cudaSuccess ) {
// Read (and thereby clear) the runtime's recorded last error before throwing.
82 cudaGetLastError();
83 throw GaudiException( err_fmt( err, __FILE__, __LINE__ ), "CUDAStreamException", StatusCode::FAILURE );
84 }
// Synchronise on the freshly created stream; a failure is reported the same way.
85 err = cudaStreamSynchronize( m_stream );
86 if ( err != cudaSuccess ) {
87 cudaGetLastError();
88 throw GaudiException( err_fmt( err, __FILE__, __LINE__ ), "CUDAStreamException", StatusCode::FAILURE );
89 }
90 }
91 }
92
// NOTE(review): the signature line of this definition is missing from this
// listing; the log message below identifies it as the Stream destructor.
//
// Reports (through the parent's error stream) a non-zero dependent count and
// a failed await(), then returns the stream to available_streams instead of
// destroying it. Nothing is thrown from here; problems are only logged.
94 if ( m_dependents != 0 ) {
95 m_parent->error() << std::format( "Stream destroyed before all its dependents ({} remaining)", m_dependents )
96 << endmsg;
97 }
// await() is called before the stream is handed back to the pool; its failure is logged, not propagated.
98 if ( await().isFailure() ) { m_parent->error() << "Error in Stream destructor" << endmsg; }
// Put the stream back in the pool so it can be picked up again by pop().
99 available_streams.push( m_stream );
100 }
101
// NOTE(review): the signature line of this definition is missing from this
// listing; presumably this is Stream::await(), which returns a StatusCode.
//
// Waits on m_stream through boost::fibers::cuda::waitfor_all and inspects the
// CUDA error code in the result. On error: logs the formatted message via the
// parent and returns StatusCode::FAILURE. Otherwise returns whatever
// m_parent->restoreAfterSuspend() returns.
103 auto res = boost::fibers::cuda::waitfor_all( m_stream );
// Element 1 of the result carries the cudaError_t for the wait.
104 cudaError_t temp_error = std::get<1>( res );
105 if ( ( temp_error ) != cudaSuccess ) {
// Read (and thereby clear) the runtime's recorded last error before reporting.
106 cudaGetLastError();
107 std::string errmsg = err_fmt( temp_error, __FILE__, __LINE__ );
108 m_parent->error() << errmsg << endmsg;
109 return StatusCode::FAILURE;
110 }
// NOTE(review): restoreAfterSuspend() is defined outside this file — confirm its contract with the parent class.
111 return m_parent->restoreAfterSuspend();
112 }
113} // namespace Gaudi::CUDA
MsgStream & endmsg(MsgStream &s)
MsgStream Modifier: endmsg. Calls the output method of the MsgStream.
Definition MsgStream.h:198
Base class for asynchronous algorithms.
const Gaudi::AsynchronousAlgorithm * parent()
Access the parent algorithm.
Definition CUDAStream.h:36
StatusCode await()
Yield fiber until stream is done.
Stream(const Gaudi::AsynchronousAlgorithm *parent)
Create a new Stream. Should happen once per algorithm.
cudaStream_t m_stream
Definition CUDAStream.h:23
const Gaudi::AsynchronousAlgorithm * m_parent
Definition CUDAStream.h:24
Define general base for Gaudi exception.
This class is used for returning status codes from appropriate routines.
Definition StatusCode.h:64
constexpr static const auto FAILURE
Definition StatusCode.h:100